From 92ffac625ec7b32d2a82f89146497df67def1a47 Mon Sep 17 00:00:00 2001 From: Adrien Raffin-Caboisse Date: Tue, 22 Feb 2022 12:45:50 +0100 Subject: [PATCH] feat(namespace): add normalization function for namespace --- dns.go | 6 +- machine.go | 6 +- namespaces.go | 57 ++++++++++++++ namespaces_test.go | 61 +++++++++++++++ oidc_test.go | 180 --------------------------------------------- 5 files changed, 120 insertions(+), 190 deletions(-) delete mode 100644 oidc_test.go diff --git a/dns.go b/dns.go index 45e0fae..9a91cff 100644 --- a/dns.go +++ b/dns.go @@ -165,11 +165,7 @@ func getMapResponseDNSConfig( dnsConfig.Domains, fmt.Sprintf( "%s.%s", - strings.ReplaceAll( - machine.Namespace.Name, - "@", - ".", - ), // Replace @ with . for valid domain for machine + machine.Namespace.Name, baseDomain, ), ) diff --git a/machine.go b/machine.go index ee48342..25edd1d 100644 --- a/machine.go +++ b/machine.go @@ -724,11 +724,7 @@ func (machine Machine) toNode( hostname = fmt.Sprintf( "%s.%s.%s", machine.Name, - strings.ReplaceAll( - machine.Namespace.Name, - "@", - ".", - ), // Replace @ with . for valid domain for machine + machine.Namespace.Name, baseDomain, ) } else { diff --git a/namespaces.go b/namespaces.go index bdd440c..8160b03 100644 --- a/namespaces.go +++ b/namespaces.go @@ -2,7 +2,10 @@ package headscale import ( "errors" + "fmt" + "regexp" "strconv" + "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -16,8 +19,11 @@ const ( errNamespaceExists = Error("Namespace already exists") errNamespaceNotFound = Error("Namespace not found") errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") + errInvalidNamespaceName = Error("Invalid namespace name") ) +var normalizeNamespaceRegex = regexp.MustCompile("[^a-z0-9-.]+") + // Namespace is the way Headscale implements the concept of users in Tailscale // // At the end of the day, users in Tailscale are some kind of 'bubbles' or namespaces @@ -30,7 +36,12 @@ type Namespace struct { // CreateNamespace creates a new Namespace. Returns error if could not be created // or another namespace already exists. func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { + var err error namespace := Namespace{} + name, err = NormalizeNamespaceName(name) + if err != nil { + return nil, err + } if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil { return nil, errNamespaceExists } @@ -50,6 +61,10 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { // DestroyNamespace destroys a Namespace. Returns error if the Namespace does // not exist or if there are machines associated with it. func (h *Headscale) DestroyNamespace(name string) error { + name, err := NormalizeNamespaceName(name) + if err != nil { + return err + } namespace, err := h.GetNamespace(name) if err != nil { return errNamespaceNotFound @@ -84,10 +99,15 @@ func (h *Headscale) DestroyNamespace(name string) error { // RenameNamespace renames a Namespace. Returns error if the Namespace does // not exist or if another Namespace exists with the new name. func (h *Headscale) RenameNamespace(oldName, newName string) error { + var err error oldNamespace, err := h.GetNamespace(oldName) if err != nil { return err } + newName, err = NormalizeNamespaceName(newName) + if err != nil { + return err + } _, err = h.GetNamespace(newName) if err == nil { return errNamespaceExists @@ -108,6 +128,10 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error { // GetNamespace fetches a namespace by name. func (h *Headscale) GetNamespace(name string) (*Namespace, error) { namespace := Namespace{} + name, err := NormalizeNamespaceName(name) + if err != nil { + return nil, err + } if result := h.db.First(&namespace, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, @@ -130,6 +154,10 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) { // ListMachinesInNamespace gets all the nodes in a given namespace. func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { + name, err := NormalizeNamespaceName(name) + if err != nil { + return nil, err + } namespace, err := h.GetNamespace(name) if err != nil { return nil, err @@ -145,6 +173,10 @@ func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { // ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace. func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) { + name, err := NormalizeNamespaceName(name) + if err != nil { + return nil, err + } namespace, err := h.GetNamespace(name) if err != nil { return nil, err @@ -170,6 +202,10 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error // SetMachineNamespace assigns a Machine to a namespace. func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error { + namespaceName, err := NormalizeNamespaceName(namespaceName) + if err != nil { + return err + } namespace, err := h.GetNamespace(namespaceName) if err != nil { return err @@ -233,3 +269,24 @@ func (n *Namespace) toProto() *v1.Namespace { CreatedAt: timestamppb.New(n.CreatedAt), } } + +// NormalizeNamespaceName will replace forbidden chars in namespace +// it can also return an error if the namespace doesn't respect RFC 952 and 1123 +func NormalizeNamespaceName(name string) (string, error) { + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "@", ".") + name = strings.ReplaceAll(name, "'", "") + name = normalizeNamespaceRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > 63 { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + errInvalidNamespaceName, + ) + } + } + + return name, nil +} diff --git a/namespaces_test.go b/namespaces_test.go index d07deb9..cf2b323 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -1,6 +1,8 @@ package headscale import ( + "testing" + "github.com/rs/zerolog/log" "gopkg.in/check.v1" "gorm.io/gorm" @@ -239,3 +241,62 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { } c.Assert(found, check.Equals, true) } + +func TestNormalizeNamespaceName(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "normalize simple name", + args: args{name: "normalize-simple.name"}, + want: "normalize-simple.name", + wantErr: false, + }, + { + name: "normalize an email", + args: args{name: "foo.bar@example.com"}, + want: "foo.bar.example.com", + wantErr: false, + }, + { + name: "normalize complex email", + args: args{name: "foo.bar+complex-email@example.com"}, + want: "foo.bar-complex-email.example.com", + wantErr: false, + }, + { + name: "namespace name with space", + args: args{name: "name space"}, + want: "name-space", + wantErr: false, + }, + { + name: "namespace with quote", + args: args{name: "Jamie's iPhone 5"}, + want: "jamies-iphone-5", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeNamespaceName(tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf( + "NormalizeNamespaceName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + return + } + if got != tt.want { + t.Errorf("NormalizeNamespaceName() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/oidc_test.go b/oidc_test.go deleted file mode 100644 index d50027a..0000000 --- a/oidc_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package headscale - -import ( - "sync" - "testing" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" - "gorm.io/gorm" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func TestHeadscale_getNamespaceFromEmail(t *testing.T) { - type fields struct { - cfg Config - db *gorm.DB - dbString string - dbType string - dbDebug bool - privateKey *key.MachinePrivate - aclPolicy *ACLPolicy - aclRules []tailcfg.FilterRule - lastStateChange sync.Map - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - oidcStateCache *cache.Cache - } - type args struct { - email string - } - tests := []struct { - name string - fields fields - args args - want string - want1 bool - }{ - { - name: "match all", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*": "space", - }, - }, - }, - }, - args: args{ - email: "test@example.no", - }, - want: "space", - want1: true, - }, - { - name: "match user", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - "specific@user\\.no": "user-namespace", - }, - }, - }, - }, - args: args{ - email: "specific@user.no", - }, - want: "user-namespace", - want1: true, - }, - { - name: "match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@example\\.no": "example", - }, - }, - }, - }, - args: args{ - email: "test@example.no", - }, - want: "example", - want1: true, - }, - { - name: "multi match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@example\\.no": "exammple", - ".*@gmail\\.com": "gmail", - }, - }, - }, - }, - args: args{ - email: "someuser@gmail.com", - }, - want: "gmail", - want1: true, - }, - { - name: "no match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@dontknow.no": "never", - }, - }, - }, - }, - args: args{ - email: "test@wedontknow.no", - }, - want: "", - want1: false, - }, - { - name: "multi no match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@dontknow.no": "never", - ".*@wedontknow.no": "other", - ".*\\.no": "stuffy", - }, - }, - }, - }, - args: args{ - email: "tasy@nonofthem.com", - }, - want: "", - want1: false, - }, - } - //nolint - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - app := &Headscale{ - cfg: test.fields.cfg, - db: test.fields.db, - dbString: test.fields.dbString, - dbType: test.fields.dbType, - dbDebug: test.fields.dbDebug, - privateKey: test.fields.privateKey, - aclPolicy: test.fields.aclPolicy, - aclRules: test.fields.aclRules, - lastStateChange: test.fields.lastStateChange, - oidcProvider: test.fields.oidcProvider, - oauth2Config: test.fields.oauth2Config, - oidcStateCache: test.fields.oidcStateCache, - } - got, got1 := app.getNamespaceFromEmail(test.args.email) - if got != test.want { - t.Errorf( - "Headscale.getNamespaceFromEmail() got = %v, want %v", - got, - test.want, - ) - } - if got1 != test.want1 { - t.Errorf( - "Headscale.getNamespaceFromEmail() got1 = %v, want %v", - got1, - test.want1, - ) - } - }) - } -}