diff --git a/CHANGELOG.md b/CHANGELOG.md index 26b575b..d44db56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,10 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh - Add API Key support - Enable remote control of `headscale` via CLI [docs](docs/remote-cli.md) - Enable HTTP API (beta, subject to change) +- OpenID Connect users will be mapped per namespaces + - Each user will get its own namespace, created if it does not exist + - `oidc.domain_map` option has been removed + - `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml)) **Changes**: diff --git a/app.go b/app.go index 68d933c..eda670e 100644 --- a/app.go +++ b/app.go @@ -112,10 +112,10 @@ type Config struct { } type OIDCConfig struct { - Issuer string - ClientID string - ClientSecret string - MatchMap map[string]string + Issuer string + ClientID string + ClientSecret string + StripEmaildomain bool } type DERPConfig struct { diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index e3dce6b..97c2440 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,7 +10,6 @@ import ( "net/url" "os" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -65,6 +64,8 @@ func LoadConfig(path string) error { viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.insecure", false) + viper.SetDefault("oidc.strip_email_domain", true) + if err := viper.ReadInConfig(); err != nil { return fmt.Errorf("fatal error reading config file: %w", err) } @@ -344,9 +345,10 @@ func getHeadscaleConfig() headscale.Config { UnixSocketPermission: GetFileMode("unix_socket_permission"), OIDC: headscale.OIDCConfig{ - Issuer: viper.GetString("oidc.issuer"), - ClientID: viper.GetString("oidc.client_id"), - ClientSecret: viper.GetString("oidc.client_secret"), + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), }, CLI: headscale.CLIConfig{ @@ -376,8 +378,6 @@ func getHeadscaleApp() (*headscale.Headscale, error) { cfg := getHeadscaleConfig() - cfg.OIDC.MatchMap = loadOIDCMatchMap() - app, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -534,18 +534,6 @@ func (tokenAuth) RequireTransportSecurity() bool { return true } -// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in -// the match map is valid regex strings. -func loadOIDCMatchMap() map[string]string { - strMap := viper.GetStringMapString("oidc.domain_map") - - for oidcMatcher := range strMap { - _ = regexp.MustCompile(oidcMatcher) - } - - return strMap -} - func GetFileMode(key string) fs.FileMode { modeStr := viper.GetString(key) diff --git a/config-example.yaml b/config-example.yaml index fe0647e..bac3b0c 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -187,7 +187,9 @@ unix_socket_permission: "0770" # client_id: "your-oidc-client-id" # client_secret: "your-oidc-client-secret" # -# # Domain map is used to map incomming users (by their email) to -# # a namespace. The key can be a string, or regex. -# domain_map: -# ".*": default-namespace +# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. +# This will transform `first-name.last-name@example.com` to the namespace `first-name.last-name` +# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following +# namespace: `first-name.last-name.example.com` +# +# strip_email_domain: true diff --git a/dns.go b/dns.go index 45e0fae..9a91cff 100644 --- a/dns.go +++ b/dns.go @@ -165,11 +165,7 @@ func getMapResponseDNSConfig( dnsConfig.Domains, fmt.Sprintf( "%s.%s", - strings.ReplaceAll( - machine.Namespace.Name, - "@", - ".", - ), // Replace @ with . for valid domain for machine + machine.Namespace.Name, baseDomain, ), ) diff --git a/machine.go b/machine.go index 892fa3d..dac4459 100644 --- a/machine.go +++ b/machine.go @@ -24,6 +24,11 @@ const ( errMachineAlreadyRegistered = Error("machine already registered") errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineAddressesInvalid = Error("failed to parse machine addresses") + errHostnameTooLong = Error("Hostname too long") +) + +const ( + maxHostnameLength = 255 ) // Machine is a Headscale client. @@ -595,13 +600,16 @@ func (machine Machine) toNode( hostname = fmt.Sprintf( "%s.%s.%s", machine.Name, - strings.ReplaceAll( - machine.Namespace.Name, - "@", - ".", - ), // Replace @ with . for valid domain for machine + machine.Namespace.Name, baseDomain, ) + if len(hostname) > maxHostnameLength { + return nil, fmt.Errorf( + "hostname %q is too long it cannot except 255 ASCII chars: %w", + hostname, + errHostnameTooLong, + ) + } } else { hostname = machine.Name } diff --git a/namespaces.go b/namespaces.go index 2eeee81..3151784 100644 --- a/namespaces.go +++ b/namespaces.go @@ -2,7 +2,10 @@ package headscale import ( "errors" + "fmt" + "regexp" "strconv" + "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -16,8 +19,16 @@ const ( errNamespaceExists = Error("Namespace already exists") errNamespaceNotFound = Error("Namespace not found") errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") + errInvalidNamespaceName = Error("Invalid namespace name") ) +const ( + // value related to RFC 1123 and 952. + labelHostnameLength = 63 +) + +var invalidCharsInNamespaceRegex = regexp.MustCompile("[^a-z0-9-.]+") + // Namespace is the way Headscale implements the concept of users in Tailscale // // At the end of the day, users in Tailscale are some kind of 'bubbles' or namespaces @@ -30,6 +41,10 @@ type Namespace struct { // CreateNamespace creates a new Namespace. Returns error if could not be created // or another namespace already exists. func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { + err := CheckNamespaceName(name) + if err != nil { + return nil, err + } namespace := Namespace{} if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil { return nil, errNamespaceExists @@ -84,10 +99,15 @@ func (h *Headscale) DestroyNamespace(name string) error { // RenameNamespace renames a Namespace. Returns error if the Namespace does // not exist or if another Namespace exists with the new name. func (h *Headscale) RenameNamespace(oldName, newName string) error { + var err error oldNamespace, err := h.GetNamespace(oldName) if err != nil { return err } + err = CheckNamespaceName(newName) + if err != nil { + return err + } _, err = h.GetNamespace(newName) if err == nil { return errNamespaceExists @@ -130,6 +150,10 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) { // ListMachinesInNamespace gets all the nodes in a given namespace. func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { + err := CheckNamespaceName(name) + if err != nil { + return nil, err + } namespace, err := h.GetNamespace(name) if err != nil { return nil, err @@ -145,6 +169,10 @@ func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { // SetMachineNamespace assigns a Machine to a namespace. func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error { + err := CheckNamespaceName(namespaceName) + if err != nil { + return err + } namespace, err := h.GetNamespace(namespaceName) if err != nil { return err @@ -208,3 +236,55 @@ func (n *Namespace) toProto() *v1.Namespace { CreatedAt: timestamppb.New(n.CreatedAt), } } + +// NormalizeNamespaceName will replace forbidden chars in namespace +// it can also return an error if the namespace doesn't respect RFC 952 and 1123. +func NormalizeNamespaceName(name string, stripEmailDomain bool) (string, error) { + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInNamespaceRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > labelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + errInvalidNamespaceName, + ) + } + } + + return name, nil +} + +func CheckNamespaceName(name string) error { + if len(name) > labelHostnameLength { + return fmt.Errorf( + "Namespace must not be over 63 chars. %v doesn't comply with this rule: %w", + name, + errInvalidNamespaceName, + ) + } + if strings.ToLower(name) != name { + return fmt.Errorf( + "Namespace name should be lowercase. %v doesn't comply with this rule: %w", + name, + errInvalidNamespaceName, + ) + } + if invalidCharsInNamespaceRegex.MatchString(name) { + return fmt.Errorf( + "Namespace name should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", + name, + errInvalidNamespaceName, + ) + } + + return nil +} diff --git a/namespaces_test.go b/namespaces_test.go index 77c33ad..585ba93 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -1,6 +1,8 @@ package headscale import ( + "testing" + "gopkg.in/check.v1" "gorm.io/gorm" "inet.af/netaddr" @@ -71,23 +73,23 @@ func (s *Suite) TestRenameNamespace(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(namespaces), check.Equals, 1) - err = app.RenameNamespace("test", "test_renamed") + err = app.RenameNamespace("test", "test-renamed") c.Assert(err, check.IsNil) _, err = app.GetNamespace("test") c.Assert(err, check.Equals, errNamespaceNotFound) - _, err = app.GetNamespace("test_renamed") + _, err = app.GetNamespace("test-renamed") c.Assert(err, check.IsNil) - err = app.RenameNamespace("test_does_not_exit", "test") + err = app.RenameNamespace("test-does-not-exit", "test") c.Assert(err, check.Equals, errNamespaceNotFound) namespaceTest2, err := app.CreateNamespace("test2") c.Assert(err, check.IsNil) c.Assert(namespaceTest2.Name, check.Equals, "test2") - err = app.RenameNamespace("test2", "test_renamed") + err = app.RenameNamespace("test2", "test-renamed") c.Assert(err, check.Equals, errNamespaceExists) } @@ -235,3 +237,143 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { } c.Assert(found, check.Equals, true) } + +func TestNormalizeNamespaceName(t *testing.T) { + type args struct { + name string + stripEmailDomain bool + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "normalize simple name", + args: args{ + name: "normalize-simple.name", + stripEmailDomain: false, + }, + want: "normalize-simple.name", + wantErr: false, + }, + { + name: "normalize an email", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: false, + }, + want: "foo.bar.example.com", + wantErr: false, + }, + { + name: "normalize an email domain should be removed", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: true, + }, + want: "foo.bar", + wantErr: false, + }, + { + name: "strip enabled no email passed as argument", + args: args{ + name: "not-email-and-strip-enabled", + stripEmailDomain: true, + }, + want: "not-email-and-strip-enabled", + wantErr: false, + }, + { + name: "normalize complex email", + args: args{ + name: "foo.bar+complex-email@example.com", + stripEmailDomain: false, + }, + want: "foo.bar-complex-email.example.com", + wantErr: false, + }, + { + name: "namespace name with space", + args: args{ + name: "name space", + stripEmailDomain: false, + }, + want: "name-space", + wantErr: false, + }, + { + name: "namespace with quote", + args: args{ + name: "Jamie's iPhone 5", + stripEmailDomain: false, + }, + want: "jamies-iphone-5", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeNamespaceName(tt.args.name, tt.args.stripEmailDomain) + if (err != nil) != tt.wantErr { + t.Errorf( + "NormalizeNamespaceName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + if got != tt.want { + t.Errorf("NormalizeNamespaceName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckNamespaceName(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid: namespace", + args: args{name: "valid-namespace"}, + wantErr: false, + }, + { + name: "invalid: capitalized namespace", + args: args{name: "Invalid-CapItaLIzed-namespace"}, + wantErr: true, + }, + { + name: "invalid: email as namespace", + args: args{name: "foo.bar@example.com"}, + wantErr: true, + }, + { + name: "invalid: chars in namespace name", + args: args{name: "super-namespace+name"}, + wantErr: true, + }, + { + name: "invalid: too long name for namespace", + args: args{ + name: "super-long-namespace-name-that-should-be-a-little-more-than-63-chars", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckNamespaceName(tt.args.name); (err != nil) != tt.wantErr { + t.Errorf("CheckNamespaceName() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/oidc.go b/oidc.go index cd77d29..edea32a 100644 --- a/oidc.go +++ b/oidc.go @@ -9,7 +9,6 @@ import ( "fmt" "html/template" "net/http" - "regexp" "strings" "time" @@ -282,113 +281,92 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { now := time.Now().UTC() - if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { - // register the machine if it's new - if !machine.Registered { - log.Debug().Msg("Registering new machine after successful callback") - - namespace, err := h.GetNamespace(namespaceName) - if errors.Is(err, gorm.ErrRecordNotFound) { - namespace, err = h.CreateNamespace(namespaceName) - - if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) - - return - } - } else if err != nil { - log.Error(). - Caller(). - Err(err). - Str("namespace", namespaceName). - Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) - - return - } - - h.ipAllocationMutex.Lock() - - ips, err := h.getAvailableIPs() - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not get an IP from the pool") - ctx.String( - http.StatusInternalServerError, - "could not get an IP from the pool", - ) - - return - } - - machine.IPAddresses = ips - machine.NamespaceID = namespace.ID - machine.Registered = true - machine.RegisterMethod = RegisterMethodOIDC - machine.LastSuccessfulUpdate = &now - machine.Expiry = &requestedTime - h.db.Save(&machine) - - h.ipAllocationMutex.Unlock() - } - - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: claims.Email, - Verb: "Authenticated", - }); err != nil { - log.Error(). - Str("func", "OIDCCallback"). - Str("type", "authenticate"). - Err(err). - Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) - } - - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + namespaceName, err := NormalizeNamespaceName( + claims.Email, + h.cfg.OIDC.StripEmaildomain, + ) + if err != nil { + log.Error().Err(err).Caller().Msgf("couldn't normalize email") + ctx.String( + http.StatusInternalServerError, + "couldn't normalize email", + ) return } + // register the machine if it's new + if !machine.Registered { + log.Debug().Msg("Registering new machine after successful callback") - log.Error(). - Caller(). - Str("email", claims.Email). - Str("username", claims.Username). - Str("machine", machine.Name). - Msg("Email could not be mapped to a namespace") - ctx.String( - http.StatusBadRequest, - "email from claim could not be mapped to a namespace", - ) -} + namespace, err := h.GetNamespace(namespaceName) + if errors.Is(err, gorm.ErrRecordNotFound) { + namespace, err = h.CreateNamespace(namespaceName) -// getNamespaceFromEmail passes the users email through a list of "matchers" -// and iterates through them until it matches and returns a namespace. -// If no match is found, an empty string will be returned. -// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration. -func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { - for match, namespace := range h.cfg.OIDC.MatchMap { - regex := regexp.MustCompile(match) - if regex.MatchString(email) { - return namespace, true + if err != nil { + log.Error(). + Err(err). + Caller(). + Msgf("could not create new namespace '%s'", namespaceName) + ctx.String( + http.StatusInternalServerError, + "could not create new namespace", + ) + + return + } + } else if err != nil { + log.Error(). + Caller(). + Err(err). + Str("namespace", namespaceName). + Msg("could not find or create namespace") + ctx.String( + http.StatusInternalServerError, + "could not find or create namespace", + ) + + return } + + ips, err := h.getAvailableIPs() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not get an IP from the pool") + ctx.String( + http.StatusInternalServerError, + "could not get an IP from the pool", + ) + + return + } + + machine.IPAddresses = ips + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = RegisterMethodOIDC + machine.LastSuccessfulUpdate = &now + machine.Expiry = &requestedTime + h.db.Save(&machine) } - return "", false + var content bytes.Buffer + if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ + User: claims.Email, + Verb: "Authenticated", + }); err != nil { + log.Error(). + Str("func", "OIDCCallback"). + Str("type", "authenticate"). + Err(err). + Msg("Could not render OIDC callback template") + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render OIDC callback template"), + ) + } + + ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) } diff --git a/oidc_test.go b/oidc_test.go deleted file mode 100644 index d50027a..0000000 --- a/oidc_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package headscale - -import ( - "sync" - "testing" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" - "gorm.io/gorm" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func TestHeadscale_getNamespaceFromEmail(t *testing.T) { - type fields struct { - cfg Config - db *gorm.DB - dbString string - dbType string - dbDebug bool - privateKey *key.MachinePrivate - aclPolicy *ACLPolicy - aclRules []tailcfg.FilterRule - lastStateChange sync.Map - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - oidcStateCache *cache.Cache - } - type args struct { - email string - } - tests := []struct { - name string - fields fields - args args - want string - want1 bool - }{ - { - name: "match all", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*": "space", - }, - }, - }, - }, - args: args{ - email: "test@example.no", - }, - want: "space", - want1: true, - }, - { - name: "match user", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - "specific@user\\.no": "user-namespace", - }, - }, - }, - }, - args: args{ - email: "specific@user.no", - }, - want: "user-namespace", - want1: true, - }, - { - name: "match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@example\\.no": "example", - }, - }, - }, - }, - args: args{ - email: "test@example.no", - }, - want: "example", - want1: true, - }, - { - name: "multi match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@example\\.no": "exammple", - ".*@gmail\\.com": "gmail", - }, - }, - }, - }, - args: args{ - email: "someuser@gmail.com", - }, - want: "gmail", - want1: true, - }, - { - name: "no match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@dontknow.no": "never", - }, - }, - }, - }, - args: args{ - email: "test@wedontknow.no", - }, - want: "", - want1: false, - }, - { - name: "multi no match domain", - fields: fields{ - cfg: Config{ - OIDC: OIDCConfig{ - MatchMap: map[string]string{ - ".*@dontknow.no": "never", - ".*@wedontknow.no": "other", - ".*\\.no": "stuffy", - }, - }, - }, - }, - args: args{ - email: "tasy@nonofthem.com", - }, - want: "", - want1: false, - }, - } - //nolint - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - app := &Headscale{ - cfg: test.fields.cfg, - db: test.fields.db, - dbString: test.fields.dbString, - dbType: test.fields.dbType, - dbDebug: test.fields.dbDebug, - privateKey: test.fields.privateKey, - aclPolicy: test.fields.aclPolicy, - aclRules: test.fields.aclRules, - lastStateChange: test.fields.lastStateChange, - oidcProvider: test.fields.oidcProvider, - oauth2Config: test.fields.oauth2Config, - oidcStateCache: test.fields.oidcStateCache, - } - got, got1 := app.getNamespaceFromEmail(test.args.email) - if got != test.want { - t.Errorf( - "Headscale.getNamespaceFromEmail() got = %v, want %v", - got, - test.want, - ) - } - if got1 != test.want1 { - t.Errorf( - "Headscale.getNamespaceFromEmail() got1 = %v, want %v", - got1, - test.want1, - ) - } - }) - } -} diff --git a/utils_test.go b/utils_test.go index 896040c..271013a 100644 --- a/utils_test.go +++ b/utils_test.go @@ -20,7 +20,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) - namespace, err := app.CreateNamespace("test_ip") + namespace, err := app.CreateNamespace("test-ip") c.Assert(err, check.IsNil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) @@ -160,7 +160,7 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) - namespace, err := app.CreateNamespace("test_ip") + namespace, err := app.CreateNamespace("test-ip") c.Assert(err, check.IsNil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)