diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index 8d8e209..f6f643a 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -3,6 +3,7 @@ package cli import ( "fmt" "strconv" + "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -33,6 +34,8 @@ func init() { Bool("ephemeral", false, "Preauthkey for ephemeral nodes") createPreAuthKeyCmd.Flags(). StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)") + createPreAuthKeyCmd.Flags(). + StringSlice("tags", []string{}, "Tags to automatically assign to node") } var preauthkeysCmd = &cobra.Command{ @@ -81,7 +84,16 @@ var listPreAuthKeys = &cobra.Command{ } tableData := pterm.TableData{ - {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}, + { + "ID", + "Key", + "Reusable", + "Ephemeral", + "Used", + "Expiration", + "Created", + "Tags", + }, } for _, key := range response.PreAuthKeys { expiration := "-" @@ -96,6 +108,15 @@ var listPreAuthKeys = &cobra.Command{ reusable = fmt.Sprintf("%v", key.GetReusable()) } + var aclTags string + + if len(key.AclTags) > 0 { + for _, tag := range key.AclTags { + aclTags += "," + tag + } + aclTags = strings.TrimLeft(aclTags, ",") + } + tableData = append(tableData, []string{ key.GetId(), key.GetKey(), @@ -104,6 +125,7 @@ var listPreAuthKeys = &cobra.Command{ strconv.FormatBool(key.GetUsed()), expiration, key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), + aclTags, }) } @@ -136,6 +158,7 @@ var createPreAuthKeyCmd = &cobra.Command{ reusable, _ := cmd.Flags().GetBool("reusable") ephemeral, _ := cmd.Flags().GetBool("ephemeral") + tags, _ := cmd.Flags().GetStringSlice("tags") log.Trace(). Bool("reusable", reusable). @@ -147,6 +170,7 @@ var createPreAuthKeyCmd = &cobra.Command{ Namespace: namespace, Reusable: reusable, Ephemeral: ephemeral, + AclTags: tags, } durationStr, _ := cmd.Flags().GetString("expiration") diff --git a/db.go b/db.go index f0a0a59..e29256b 100644 --- a/db.go +++ b/db.go @@ -131,6 +131,11 @@ func (h *Headscale) initDB() error { return err } + err = db.AutoMigrate(&PreAuthKeyAclTag{}) + if err != nil { + return err + } + _ = db.Migrator().DropTable("shared_machines") err = db.AutoMigrate(&APIKey{}) diff --git a/grpcv1.go b/grpcv1.go index e3db5dd..620b8fe 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -1,4 +1,4 @@ -//nolint +// nolint package headscale import ( @@ -111,6 +111,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( request.GetReusable(), request.GetEphemeral(), &expiration, + request.AclTags, ) if err != nil { return nil, err @@ -141,6 +142,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { preAuthKeys, err := api.h.ListPreAuthKeys(request.GetNamespace()) + if err != nil { return nil, err } diff --git a/preauth_keys.go b/preauth_keys.go index f120f45..8359d28 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -29,17 +29,26 @@ type PreAuthKey struct { Reusable bool Ephemeral bool `gorm:"default:false"` Used bool `gorm:"default:false"` + AclTags []PreAuthKeyAclTag CreatedAt *time.Time Expiration *time.Time } +// PreAuthKeyAclTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey +type PreAuthKeyAclTag struct { + ID uint64 `gorm:"primary_key"` + PreAuthKeyID uint64 + Tag string +} + // CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it. func (h *Headscale) CreatePreAuthKey( namespaceName string, reusable bool, ephemeral bool, expiration *time.Time, + aclTags []string, ) (*PreAuthKey, error) { namespace, err := h.GetNamespace(namespaceName) if err != nil { @@ -62,8 +71,26 @@ func (h *Headscale) CreatePreAuthKey( Expiration: expiration, } - if err := h.db.Save(&key).Error; err != nil { - return nil, fmt.Errorf("failed to create key in the database: %w", err) + err = h.db.Transaction(func(db *gorm.DB) error { + if err := db.Save(&key).Error; err != nil { + return fmt.Errorf("failed to create key in the database: %w", err) + } + + if len(aclTags) > 0 { + for _, tag := range aclTags { + if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + return fmt.Errorf( + "failed to create key tag in the database: %w", + err, + ) + } + } + } + return nil + }) + + if err != nil { + return nil, err } return &key, nil @@ -77,7 +104,7 @@ func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) } keys := []PreAuthKey{} - if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil { + if err := h.db.Preload("Namespace").Preload("AclTags").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -101,11 +128,17 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { - if result := h.db.Unscoped().Delete(pak); result.Error != nil { - return result.Error - } + return h.db.Transaction(func(db *gorm.DB) error { + if result := db.Unscoped().Delete(PreAuthKeyAclTag{PreAuthKeyID: pak.ID}); result.Error != nil { + return result.Error + } - return nil + if result := db.Unscoped().Delete(pak); result.Error != nil { + return result.Error + } + + return nil + }) } // MarkExpirePreAuthKey marks a PreAuthKey as expired. @@ -131,7 +164,7 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error { // If returns no error and a PreAuthKey, it can be used. func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { pak := PreAuthKey{} - if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is( + if result := h.db.Preload("Namespace").Preload("AclTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -176,6 +209,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey { Ephemeral: key.Ephemeral, Reusable: key.Reusable, Used: key.Used, + AclTags: make([]string, len(key.AclTags)), } if key.Expiration != nil { @@ -186,5 +220,11 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey { protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) } + if len(key.AclTags) > 0 { + for idx := range key.AclTags { + protoKey.AclTags[idx] = key.AclTags[0].Tag + } + } + return &protoKey }