This commit is contained in:
Tao Chen 2024-05-09 16:42:39 +02:00
parent 77c6bcacca
commit bd78f564b9
2 changed files with 36 additions and 54 deletions

View file

@ -350,7 +350,7 @@ func extractIDTokenClaims(
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims")) _, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil { if werr != nil {
util.LogErr(err,"Failed to write response") util.LogErr(err, "Failed to write response")
} }
return nil, err return nil, err
} }
@ -358,7 +358,7 @@ func extractIDTokenClaims(
// Unmarshal the claims into a map // Unmarshal the claims into a map
mappedClaims := make(map[string]interface{}) mappedClaims := make(map[string]interface{})
if err := json.Unmarshal(claims, &mappedClaims); err != nil { if err := json.Unmarshal(claims, &mappedClaims); err != nil {
util.LogErr(err,"Failed to unmarshal id token claims") util.LogErr(err, "Failed to unmarshal id token claims")
return nil, err return nil, err
} }
@ -388,24 +388,6 @@ func extractIDTokenClaims(
return &finalClaims, nil return &finalClaims, nil
} }
// {
// var claims IDTokenClaims
// if err := idToken.Claims(&claims); err != nil {
// util.LogErr(err, "Failed to decode id token claims")
// writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
// writer.WriteHeader(http.StatusBadRequest)
// _, werr := writer.Write([]byte("Failed to decode id token claims"))
// if werr != nil {
// util.LogErr(err, "Failed to write response")
// }
// return nil, err
// }
// return &claims, nil
// }
// validateOIDCAllowedDomains checks that if AllowedDomains is provided, // validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>. // that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains( func validateOIDCAllowedDomains(
@ -589,7 +571,7 @@ func getUserName(
writer http.ResponseWriter, writer http.ResponseWriter,
claims *IDTokenClaims, claims *IDTokenClaims,
stripEmaildomain bool, stripEmaildomain bool,
) (string, error) { ) (string, error) {
userName, err := util.NormalizeToFQDNRules( userName, err := util.NormalizeToFQDNRules(
claims.Username, claims.Username,
stripEmaildomain, stripEmaildomain,

View file

@ -126,34 +126,34 @@ type OIDCConfig struct {
ClientSecret string ClientSecret string
Scope []string Scope []string
ExtraParams map[string]string ExtraParams map[string]string
ClaimsMap OIDCClaimsMap ClaimsMap OIDCClaimsMap
Allowed OIDCAllowedConfig Allowed OIDCAllowedConfig
Expiry OIDCExpireConfig Expiry OIDCExpireConfig
Misc OIDCMiscConfig Misc OIDCMiscConfig
} }
type OIDCExpireConfig struct { type OIDCExpireConfig struct {
FromToken bool FromToken bool
FixedTime time.Duration FixedTime time.Duration
} }
type OIDCAllowedConfig struct { type OIDCAllowedConfig struct {
Domains []string Domains []string
Users []string Users []string
Groups []string Groups []string
} }
type OIDCClaimsMap struct { type OIDCClaimsMap struct {
Name string Name string
Username string Username string
Email string Email string
Groups string Groups string
} }
type OIDCMiscConfig struct { type OIDCMiscConfig struct {
StripEmaildomain bool StripEmaildomain bool
FlattenGroups bool FlattenGroups bool
FlattenSplter string FlattenSplter string
} }
type DERPConfig struct { type DERPConfig struct {
@ -254,9 +254,9 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.claims_map.email", "email") viper.SetDefault("oidc.claims_map.email", "email")
viper.SetDefault("oidc.claims_map.groups", "groups") viper.SetDefault("oidc.claims_map.groups", "groups")
// misc // misc
viper.SetDefault("oidc.strip_email_domain", false) viper.SetDefault("oidc.misc.strip_email_domain", false)
viper.SetDefault("oidc.flatten_groups", false) viper.SetDefault("oidc.misc.flatten_groups", false)
viper.SetDefault("oidc.flatten_splitter", "/") viper.SetDefault("oidc.misc.flatten_splitter", "/")
viper.SetDefault("logtail.enabled", false) viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false) viper.SetDefault("randomize_client_port", false)
@ -695,9 +695,9 @@ func GetOIDCConfig() (OIDCConfig, error) {
} }
// get misc config // get misc config
oidcMiscConfig := OIDCMiscConfig{ oidcMiscConfig := OIDCMiscConfig{
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), StripEmaildomain: viper.GetBool("oidc.misc.strip_email_domain"),
FlattenGroups: viper.GetBool("oidc.flatten_groups"), FlattenGroups: viper.GetBool("oidc.misc.flatten_groups"),
FlattenSplter: viper.GetString("oidc.flatten_splitter"), FlattenSplter: viper.GetString("oidc.misc.flatten_splitter"),
} }
// get client secret // get client secret
oidcClientSecret := viper.GetString("oidc.client_secret") oidcClientSecret := viper.GetString("oidc.client_secret")
@ -716,15 +716,15 @@ func GetOIDCConfig() (OIDCConfig, error) {
OnlyStartIfOIDCIsAvailable: viper.GetBool( OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available", "oidc.only_start_if_oidc_is_available",
), ),
Issuer: viper.GetString("oidc.issuer"), Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"), ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret, ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"), Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"), ExtraParams: viper.GetStringMapString("oidc.extra_params"),
Allowed: oidcAllowed, Allowed: oidcAllowed,
ClaimsMap: oidcClaimsMap, ClaimsMap: oidcClaimsMap,
Expiry: oidcExpireConfig, Expiry: oidcExpireConfig,
Misc: oidcMiscConfig, Misc: oidcMiscConfig,
} }
return OIDC, nil return OIDC, nil
} }
@ -810,9 +810,9 @@ func GetHeadscaleConfig() (*Config, error) {
UnixSocket: viper.GetString("unix_socket"), UnixSocket: viper.GetString("unix_socket"),
UnixSocketPermission: util.GetFileMode("unix_socket_permission"), UnixSocketPermission: util.GetFileMode("unix_socket_permission"),
OIDC: oidcConfig, OIDC: oidcConfig,
LogTail: logConfig, LogTail: logConfig,
RandomizeClientPort: randomizeClientPort, RandomizeClientPort: randomizeClientPort,
ACL: GetACLConfig(), ACL: GetACLConfig(),