diff --git a/app.go b/app.go index 020f6fc..fa1f5a0 100644 --- a/app.go +++ b/app.go @@ -107,9 +107,10 @@ type Config struct { } type OIDCConfig struct { - Issuer string - ClientID string - ClientSecret string + Issuer string + ClientID string + ClientSecret string + StripEmaildomain bool } type DERPConfig struct { diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index f738ad3..d38bb26 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -63,6 +63,8 @@ func LoadConfig(path string) error { viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.insecure", false) + viper.SetDefault("oidc.strip_email_domain", true) + if err := viper.ReadInConfig(); err != nil { return fmt.Errorf("fatal error reading config file: %w", err) } @@ -323,9 +325,10 @@ func getHeadscaleConfig() headscale.Config { UnixSocketPermission: GetFileMode("unix_socket_permission"), OIDC: headscale.OIDCConfig{ - Issuer: viper.GetString("oidc.issuer"), - ClientID: viper.GetString("oidc.client_id"), - ClientSecret: viper.GetString("oidc.client_secret"), + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), }, CLI: headscale.CLIConfig{ diff --git a/config-example.yaml b/config-example.yaml index 17f556b..71fdfaa 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -180,3 +180,9 @@ unix_socket_permission: "0770" # client_id: "your-oidc-client-id" # client_secret: "your-oidc-client-secret" # +# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. +# This will transform `first-name.last-name@example.com` to the namespace `first-name.last-name` +# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following +# namespace: `first-name.last-name.example.com` +# +# strip_email_domain: true diff --git a/namespaces.go b/namespaces.go index cf1efe5..b02f7a7 100644 --- a/namespaces.go +++ b/namespaces.go @@ -268,10 +268,15 @@ func (n *Namespace) toProto() *v1.Namespace { // NormalizeNamespaceName will replace forbidden chars in namespace // it can also return an error if the namespace doesn't respect RFC 952 and 1123. -func NormalizeNamespaceName(name string) (string, error) { +func NormalizeNamespaceName(name string, stripEmailDomain bool) (string, error) { name = strings.ToLower(name) - name = strings.ReplaceAll(name, "@", ".") name = strings.ReplaceAll(name, "'", "") + if stripEmailDomain { + idx := strings.Index(name, "@") + name = name[:idx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } name = invalidCharsInNamespaceRegex.ReplaceAllString(name, "-") for _, elt := range strings.Split(name, ".") { diff --git a/namespaces_test.go b/namespaces_test.go index 6b8df2b..6fb572c 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -244,7 +244,8 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { func TestNormalizeNamespaceName(t *testing.T) { type args struct { - name string + name string + stripEmailDomain bool } tests := []struct { name string @@ -253,39 +254,63 @@ func TestNormalizeNamespaceName(t *testing.T) { wantErr bool }{ { - name: "normalize simple name", - args: args{name: "normalize-simple.name"}, + name: "normalize simple name", + args: args{ + name: "normalize-simple.name", + stripEmailDomain: false, + }, want: "normalize-simple.name", wantErr: false, }, { - name: "normalize an email", - args: args{name: "foo.bar@example.com"}, + name: "normalize an email", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: false, + }, want: "foo.bar.example.com", wantErr: false, }, { - name: "normalize complex email", - args: args{name: "foo.bar+complex-email@example.com"}, + name: "normalize an email domain should be removed", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: true, + }, + want: "foo.bar", + wantErr: false, + }, + { + name: "normalize complex email", + args: args{ + name: "foo.bar+complex-email@example.com", + stripEmailDomain: false, + }, want: "foo.bar-complex-email.example.com", wantErr: false, }, { - name: "namespace name with space", - args: args{name: "name space"}, + name: "namespace name with space", + args: args{ + name: "name space", + stripEmailDomain: false, + }, want: "name-space", wantErr: false, }, { - name: "namespace with quote", - args: args{name: "Jamie's iPhone 5"}, + name: "namespace with quote", + args: args{ + name: "Jamie's iPhone 5", + stripEmailDomain: false, + }, want: "jamies-iphone-5", wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NormalizeNamespaceName(tt.args.name) + got, err := NormalizeNamespaceName(tt.args.name, tt.args.stripEmailDomain) if (err != nil) != tt.wantErr { t.Errorf( "NormalizeNamespaceName() error = %v, wantErr %v", diff --git a/oidc.go b/oidc.go index 78caa64..2036c4d 100644 --- a/oidc.go +++ b/oidc.go @@ -281,7 +281,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { now := time.Now().UTC() - namespaceName, err := NormalizeNamespaceName(claims.Email) + namespaceName, err := NormalizeNamespaceName(claims.Email, h.cfg.OIDC.StripEmaildomain) if err != nil { log.Error().Err(err).Caller().Msgf("couldn't normalize email") ctx.String(