Adds grpc/cli support for preauthkey tags

This commit is contained in:
Benjamin George Roberts 2022-08-25 20:03:38 +10:00
parent e27a4db281
commit 791272e408
4 changed files with 81 additions and 10 deletions

View file

@ -3,6 +3,7 @@ package cli
import ( import (
"fmt" "fmt"
"strconv" "strconv"
"strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -33,6 +34,8 @@ func init() {
Bool("ephemeral", false, "Preauthkey for ephemeral nodes") Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags(). createPreAuthKeyCmd.Flags().
StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)") StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)")
createPreAuthKeyCmd.Flags().
StringSlice("tags", []string{}, "Tags to automatically assign to node")
} }
var preauthkeysCmd = &cobra.Command{ var preauthkeysCmd = &cobra.Command{
@ -81,7 +84,16 @@ var listPreAuthKeys = &cobra.Command{
} }
tableData := pterm.TableData{ tableData := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}, {
"ID",
"Key",
"Reusable",
"Ephemeral",
"Used",
"Expiration",
"Created",
"Tags",
},
} }
for _, key := range response.PreAuthKeys { for _, key := range response.PreAuthKeys {
expiration := "-" expiration := "-"
@ -96,6 +108,15 @@ var listPreAuthKeys = &cobra.Command{
reusable = fmt.Sprintf("%v", key.GetReusable()) reusable = fmt.Sprintf("%v", key.GetReusable())
} }
var aclTags string
if len(key.AclTags) > 0 {
for _, tag := range key.AclTags {
aclTags += "," + tag
}
aclTags = strings.TrimLeft(aclTags, ",")
}
tableData = append(tableData, []string{ tableData = append(tableData, []string{
key.GetId(), key.GetId(),
key.GetKey(), key.GetKey(),
@ -104,6 +125,7 @@ var listPreAuthKeys = &cobra.Command{
strconv.FormatBool(key.GetUsed()), strconv.FormatBool(key.GetUsed()),
expiration, expiration,
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
aclTags,
}) })
} }
@ -136,6 +158,7 @@ var createPreAuthKeyCmd = &cobra.Command{
reusable, _ := cmd.Flags().GetBool("reusable") reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral") ephemeral, _ := cmd.Flags().GetBool("ephemeral")
tags, _ := cmd.Flags().GetStringSlice("tags")
log.Trace(). log.Trace().
Bool("reusable", reusable). Bool("reusable", reusable).
@ -147,6 +170,7 @@ var createPreAuthKeyCmd = &cobra.Command{
Namespace: namespace, Namespace: namespace,
Reusable: reusable, Reusable: reusable,
Ephemeral: ephemeral, Ephemeral: ephemeral,
AclTags: tags,
} }
durationStr, _ := cmd.Flags().GetString("expiration") durationStr, _ := cmd.Flags().GetString("expiration")

5
db.go
View file

@ -131,6 +131,11 @@ func (h *Headscale) initDB() error {
return err return err
} }
err = db.AutoMigrate(&PreAuthKeyAclTag{})
if err != nil {
return err
}
_ = db.Migrator().DropTable("shared_machines") _ = db.Migrator().DropTable("shared_machines")
err = db.AutoMigrate(&APIKey{}) err = db.AutoMigrate(&APIKey{})

View file

@ -1,4 +1,4 @@
//nolint // nolint
package headscale package headscale
import ( import (
@ -111,6 +111,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
request.GetReusable(), request.GetReusable(),
request.GetEphemeral(), request.GetEphemeral(),
&expiration, &expiration,
request.AclTags,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -141,6 +142,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
request *v1.ListPreAuthKeysRequest, request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) { ) (*v1.ListPreAuthKeysResponse, error) {
preAuthKeys, err := api.h.ListPreAuthKeys(request.GetNamespace()) preAuthKeys, err := api.h.ListPreAuthKeys(request.GetNamespace())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -29,17 +29,26 @@ type PreAuthKey struct {
Reusable bool Reusable bool
Ephemeral bool `gorm:"default:false"` Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"` Used bool `gorm:"default:false"`
AclTags []PreAuthKeyAclTag
CreatedAt *time.Time CreatedAt *time.Time
Expiration *time.Time Expiration *time.Time
} }
// PreAuthKeyAclTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey
type PreAuthKeyAclTag struct {
ID uint64 `gorm:"primary_key"`
PreAuthKeyID uint64
Tag string
}
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it. // CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
func (h *Headscale) CreatePreAuthKey( func (h *Headscale) CreatePreAuthKey(
namespaceName string, namespaceName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string,
) (*PreAuthKey, error) { ) (*PreAuthKey, error) {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@ -62,8 +71,26 @@ func (h *Headscale) CreatePreAuthKey(
Expiration: expiration, Expiration: expiration,
} }
if err := h.db.Save(&key).Error; err != nil { err = h.db.Transaction(func(db *gorm.DB) error {
return nil, fmt.Errorf("failed to create key in the database: %w", err) if err := db.Save(&key).Error; err != nil {
return fmt.Errorf("failed to create key in the database: %w", err)
}
if len(aclTags) > 0 {
for _, tag := range aclTags {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to create key tag in the database: %w",
err,
)
}
}
}
return nil
})
if err != nil {
return nil, err
} }
return &key, nil return &key, nil
@ -77,7 +104,7 @@ func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error)
} }
keys := []PreAuthKey{} keys := []PreAuthKey{}
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil { if err := h.db.Preload("Namespace").Preload("AclTags").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
@ -101,11 +128,17 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. // does not exist.
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
if result := h.db.Unscoped().Delete(pak); result.Error != nil { return h.db.Transaction(func(db *gorm.DB) error {
return result.Error if result := db.Unscoped().Delete(PreAuthKeyAclTag{PreAuthKeyID: pak.ID}); result.Error != nil {
} return result.Error
}
return nil if result := db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
}
return nil
})
} }
// MarkExpirePreAuthKey marks a PreAuthKey as expired. // MarkExpirePreAuthKey marks a PreAuthKey as expired.
@ -131,7 +164,7 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error {
// If returns no error and a PreAuthKey, it can be used. // If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{} pak := PreAuthKey{}
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is( if result := h.db.Preload("Namespace").Preload("AclTags").First(&pak, "key = ?", k); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -176,6 +209,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
Ephemeral: key.Ephemeral, Ephemeral: key.Ephemeral,
Reusable: key.Reusable, Reusable: key.Reusable,
Used: key.Used, Used: key.Used,
AclTags: make([]string, len(key.AclTags)),
} }
if key.Expiration != nil { if key.Expiration != nil {
@ -186,5 +220,11 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
} }
if len(key.AclTags) > 0 {
for idx := range key.AclTags {
protoKey.AclTags[idx] = key.AclTags[0].Tag
}
}
return &protoKey return &protoKey
} }