From bd78f564b9210cf41caf3c70c4340dadac3a10d8 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 9 May 2024 16:42:39 +0200 Subject: [PATCH] fix bug --- hscontrol/oidc.go | 24 ++------------ hscontrol/types/config.go | 66 +++++++++++++++++++-------------------- 2 files changed, 36 insertions(+), 54 deletions(-) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 95ca39a..c6b3ef2 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -350,7 +350,7 @@ func extractIDTokenClaims( writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("Failed to decode id token claims")) if werr != nil { - util.LogErr(err,"Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err } @@ -358,7 +358,7 @@ func extractIDTokenClaims( // Unmarshal the claims into a map mappedClaims := make(map[string]interface{}) if err := json.Unmarshal(claims, &mappedClaims); err != nil { - util.LogErr(err,"Failed to unmarshal id token claims") + util.LogErr(err, "Failed to unmarshal id token claims") return nil, err } @@ -388,24 +388,6 @@ func extractIDTokenClaims( return &finalClaims, nil } -// { -// var claims IDTokenClaims -// if err := idToken.Claims(&claims); err != nil { -// util.LogErr(err, "Failed to decode id token claims") - -// writer.Header().Set("Content-Type", "text/plain; charset=utf-8") -// writer.WriteHeader(http.StatusBadRequest) -// _, werr := writer.Write([]byte("Failed to decode id token claims")) -// if werr != nil { -// util.LogErr(err, "Failed to write response") -// } - -// return nil, err -// } - -// return &claims, nil -// } - // validateOIDCAllowedDomains checks that if AllowedDomains is provided, // that the authenticated principal ends with @. func validateOIDCAllowedDomains( @@ -589,7 +571,7 @@ func getUserName( writer http.ResponseWriter, claims *IDTokenClaims, stripEmaildomain bool, -) (string, error) { +) (string, error) { userName, err := util.NormalizeToFQDNRules( claims.Username, stripEmaildomain, diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index b62df54..ecdaf51 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -126,34 +126,34 @@ type OIDCConfig struct { ClientSecret string Scope []string ExtraParams map[string]string - ClaimsMap OIDCClaimsMap - Allowed OIDCAllowedConfig + ClaimsMap OIDCClaimsMap + Allowed OIDCAllowedConfig Expiry OIDCExpireConfig - Misc OIDCMiscConfig + Misc OIDCMiscConfig } type OIDCExpireConfig struct { - FromToken bool - FixedTime time.Duration + FromToken bool + FixedTime time.Duration } type OIDCAllowedConfig struct { - Domains []string - Users []string - Groups []string + Domains []string + Users []string + Groups []string } type OIDCClaimsMap struct { - Name string - Username string - Email string - Groups string + Name string + Username string + Email string + Groups string } type OIDCMiscConfig struct { - StripEmaildomain bool - FlattenGroups bool - FlattenSplter string + StripEmaildomain bool + FlattenGroups bool + FlattenSplter string } type DERPConfig struct { @@ -254,9 +254,9 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("oidc.claims_map.email", "email") viper.SetDefault("oidc.claims_map.groups", "groups") // misc - viper.SetDefault("oidc.strip_email_domain", false) - viper.SetDefault("oidc.flatten_groups", false) - viper.SetDefault("oidc.flatten_splitter", "/") + viper.SetDefault("oidc.misc.strip_email_domain", false) + viper.SetDefault("oidc.misc.flatten_groups", false) + viper.SetDefault("oidc.misc.flatten_splitter", "/") viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) @@ -695,9 +695,9 @@ func GetOIDCConfig() (OIDCConfig, error) { } // get misc config oidcMiscConfig := OIDCMiscConfig{ - StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), - FlattenGroups: viper.GetBool("oidc.flatten_groups"), - FlattenSplter: viper.GetString("oidc.flatten_splitter"), + StripEmaildomain: viper.GetBool("oidc.misc.strip_email_domain"), + FlattenGroups: viper.GetBool("oidc.misc.flatten_groups"), + FlattenSplter: viper.GetString("oidc.misc.flatten_splitter"), } // get client secret oidcClientSecret := viper.GetString("oidc.client_secret") @@ -716,15 +716,15 @@ func GetOIDCConfig() (OIDCConfig, error) { OnlyStartIfOIDCIsAvailable: viper.GetBool( "oidc.only_start_if_oidc_is_available", ), - Issuer: viper.GetString("oidc.issuer"), - ClientID: viper.GetString("oidc.client_id"), - ClientSecret: oidcClientSecret, - Scope: viper.GetStringSlice("oidc.scope"), - ExtraParams: viper.GetStringMapString("oidc.extra_params"), - Allowed: oidcAllowed, - ClaimsMap: oidcClaimsMap, - Expiry: oidcExpireConfig, - Misc: oidcMiscConfig, + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: oidcClientSecret, + Scope: viper.GetStringSlice("oidc.scope"), + ExtraParams: viper.GetStringMapString("oidc.extra_params"), + Allowed: oidcAllowed, + ClaimsMap: oidcClaimsMap, + Expiry: oidcExpireConfig, + Misc: oidcMiscConfig, } return OIDC, nil } @@ -810,9 +810,9 @@ func GetHeadscaleConfig() (*Config, error) { UnixSocket: viper.GetString("unix_socket"), UnixSocketPermission: util.GetFileMode("unix_socket_permission"), - OIDC: oidcConfig, - LogTail: logConfig, - RandomizeClientPort: randomizeClientPort, + OIDC: oidcConfig, + LogTail: logConfig, + RandomizeClientPort: randomizeClientPort, ACL: GetACLConfig(),