re-construct oidc config
This commit is contained in:
parent
2bac80cfbf
commit
890d6e73fb
4 changed files with 210 additions and 71 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -77,11 +78,11 @@ func (h *Headscale) initOIDC() error {
|
|||
}
|
||||
|
||||
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
|
||||
if h.cfg.OIDC.UseExpiryFromToken {
|
||||
if h.cfg.OIDC.Expiry.FromToken {
|
||||
return idTokenExpiration
|
||||
}
|
||||
|
||||
return time.Now().Add(h.cfg.OIDC.Expiry)
|
||||
return time.Now().Add(h.cfg.OIDC.Expiry.FixedTime)
|
||||
}
|
||||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
|
@ -197,20 +198,20 @@ func (h *Headscale) OIDCCallback(
|
|||
// return
|
||||
// }
|
||||
|
||||
claims, err := extractIDTokenClaims(writer, idToken)
|
||||
claims, err := extractIDTokenClaims(writer, idToken, h.cfg.OIDC.ClaimsMap, h.cfg.OIDC.Misc.FlattenGroups, h.cfg.OIDC.Misc.FlattenSplter)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
|
||||
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.Allowed.Domains, claims); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
|
||||
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.Allowed.Groups, claims); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
|
||||
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.Allowed.Users, claims); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -223,8 +224,7 @@ func (h *Headscale) OIDCCallback(
|
|||
if err != nil || nodeExists {
|
||||
return
|
||||
}
|
||||
|
||||
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
|
||||
userName, err := getUserName(writer, claims, h.cfg.OIDC.Misc.StripEmaildomain)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -338,24 +338,74 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
|
|||
func extractIDTokenClaims(
|
||||
writer http.ResponseWriter,
|
||||
idToken *oidc.IDToken,
|
||||
claimsMap types.OIDCClaimsMap,
|
||||
flattenGroup bool,
|
||||
flattenSpliter string,
|
||||
) (*IDTokenClaims, error) {
|
||||
var claims IDTokenClaims
|
||||
var claims json.RawMessage
|
||||
// Parse the ID Token claims into the struct
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
util.LogErr(err, "Failed to decode id token claims")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, werr := writer.Write([]byte("Failed to decode id token claims"))
|
||||
if werr != nil {
|
||||
util.LogErr(err,"Failed to write response")
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
// Unmarshal the claims into a map
|
||||
mappedClaims := make(map[string]interface{})
|
||||
if err := json.Unmarshal(claims, &mappedClaims); err != nil {
|
||||
util.LogErr(err,"Failed to unmarshal id token claims")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Map the claims to the final struct
|
||||
var finalClaims IDTokenClaims
|
||||
if val, ok := mappedClaims[claimsMap.Name]; ok {
|
||||
finalClaims.Name = val.(string)
|
||||
}
|
||||
if val, ok := mappedClaims[claimsMap.Username]; ok {
|
||||
finalClaims.Username = val.(string)
|
||||
}
|
||||
if val, ok := mappedClaims[claimsMap.Email]; ok {
|
||||
finalClaims.Email = val.(string)
|
||||
}
|
||||
if val, ok := mappedClaims[claimsMap.Groups]; ok && val != nil {
|
||||
groups, ok := val.([]interface{})
|
||||
if ok {
|
||||
for _, group := range groups {
|
||||
finalClaims.Groups = append(finalClaims.Groups, group.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Flatten groups if needed
|
||||
if flattenGroup {
|
||||
finalClaims.Groups = flattenGroups(finalClaims.Groups, flattenSpliter)
|
||||
}
|
||||
return &finalClaims, nil
|
||||
}
|
||||
|
||||
// {
|
||||
// var claims IDTokenClaims
|
||||
// if err := idToken.Claims(&claims); err != nil {
|
||||
// util.LogErr(err, "Failed to decode id token claims")
|
||||
|
||||
// writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
// writer.WriteHeader(http.StatusBadRequest)
|
||||
// _, werr := writer.Write([]byte("Failed to decode id token claims"))
|
||||
// if werr != nil {
|
||||
// util.LogErr(err, "Failed to write response")
|
||||
// }
|
||||
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// return &claims, nil
|
||||
// }
|
||||
|
||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
||||
// that the authenticated principal ends with @<alloweddomain>.
|
||||
func validateOIDCAllowedDomains(
|
||||
|
@ -541,7 +591,7 @@ func getUserName(
|
|||
stripEmaildomain bool,
|
||||
) (string, error) {
|
||||
userName, err := util.NormalizeToFQDNRules(
|
||||
claims.Email,
|
||||
claims.Username,
|
||||
stripEmaildomain,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -653,3 +703,25 @@ func renderOIDCCallbackTemplate(
|
|||
|
||||
return &content, nil
|
||||
}
|
||||
|
||||
// flattenGroups takes a list of groups and returns a list of all groups and subgroups.
|
||||
// groups format is a list of strings with the groups separated by slashes. e.g.: ["a/b/c", "a/b/d"]
|
||||
func flattenGroups(groups []string, spliter string) []string {
|
||||
// A map to keep track of which groups we have seen
|
||||
seen := make(map[string]bool)
|
||||
var result []string
|
||||
|
||||
// Iterate over each group, format is a/b/c
|
||||
for _, group := range groups {
|
||||
// Split the group into segments, e.g. ["a", "b", "c"]
|
||||
segments := strings.Split(group, spliter)
|
||||
for _, segment := range segments {
|
||||
if !seen[segment] && segment != "" {
|
||||
seen[segment] = true
|
||||
result = append(result, segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -47,8 +47,10 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
},
|
||||
},
|
||||
OIDC: types.OIDCConfig{
|
||||
Misc: types.OIDCMiscConfig{
|
||||
StripEmaildomain: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app, err = NewHeadscale(&cfg)
|
||||
|
|
|
@ -126,12 +126,34 @@ type OIDCConfig struct {
|
|||
ClientSecret string
|
||||
Scope []string
|
||||
ExtraParams map[string]string
|
||||
AllowedDomains []string
|
||||
AllowedUsers []string
|
||||
AllowedGroups []string
|
||||
ClaimsMap OIDCClaimsMap
|
||||
Allowed OIDCAllowedConfig
|
||||
Expiry OIDCExpireConfig
|
||||
Misc OIDCMiscConfig
|
||||
}
|
||||
|
||||
type OIDCExpireConfig struct {
|
||||
FromToken bool
|
||||
FixedTime time.Duration
|
||||
}
|
||||
|
||||
type OIDCAllowedConfig struct {
|
||||
Domains []string
|
||||
Users []string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
type OIDCClaimsMap struct {
|
||||
Name string
|
||||
Username string
|
||||
Email string
|
||||
Groups string
|
||||
}
|
||||
|
||||
type OIDCMiscConfig struct {
|
||||
StripEmaildomain bool
|
||||
Expiry time.Duration
|
||||
UseExpiryFromToken bool
|
||||
FlattenGroups bool
|
||||
FlattenSplter string
|
||||
}
|
||||
|
||||
type DERPConfig struct {
|
||||
|
@ -222,10 +244,19 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
||||
|
||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||
viper.SetDefault("oidc.strip_email_domain", true)
|
||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||
viper.SetDefault("oidc.expiry", "180d")
|
||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||
// expiry
|
||||
viper.SetDefault("oidc.expiry.fixed_time", "180d")
|
||||
viper.SetDefault("oidc.expiry.from_token", false)
|
||||
// claims_map
|
||||
viper.SetDefault("oidc.claims_map.name", "name")
|
||||
viper.SetDefault("oidc.claims_map.username", "preferred_username")
|
||||
viper.SetDefault("oidc.claims_map.email", "email")
|
||||
viper.SetDefault("oidc.claims_map.groups", "groups")
|
||||
// misc
|
||||
viper.SetDefault("oidc.strip_email_domain", false)
|
||||
viper.SetDefault("oidc.flatten_groups", false)
|
||||
viper.SetDefault("oidc.flatten_splitter", "/")
|
||||
|
||||
viper.SetDefault("logtail.enabled", false)
|
||||
viper.SetDefault("randomize_client_port", false)
|
||||
|
@ -628,6 +659,76 @@ func PrefixV6() (*netip.Prefix, error) {
|
|||
return &prefixV6, nil
|
||||
}
|
||||
|
||||
func GetOIDCConfig() (OIDCConfig, error) {
|
||||
|
||||
// get expiry config
|
||||
expireConfig := OIDCExpireConfig{
|
||||
FromToken: viper.GetBool("oidc.expiry.from_token"),
|
||||
FixedTime: func() time.Duration {
|
||||
// if set to 0, we assume no expiry
|
||||
if value := viper.GetString("oidc.expiry.fixed_time"); value == "0" {
|
||||
return maxDuration
|
||||
} else {
|
||||
expiry, err := model.ParseDuration(value)
|
||||
if err != nil {
|
||||
log.Warn().Msg("failed to parse oidc.expiry.fixed_time, defaulting back to 180 days")
|
||||
|
||||
return defaultOIDCExpiryTime
|
||||
}
|
||||
|
||||
return time.Duration(expiry)
|
||||
}
|
||||
}(),
|
||||
}
|
||||
// get allowed config
|
||||
allowedConfig := OIDCAllowedConfig{
|
||||
Domains: viper.GetStringSlice("oidc.allowed.domains"),
|
||||
Users: viper.GetStringSlice("oidc.allowed.users"),
|
||||
Groups: viper.GetStringSlice("oidc.allowed.groups"),
|
||||
}
|
||||
// get claims map
|
||||
claimsMap := OIDCClaimsMap{
|
||||
Name: viper.GetString("oidc.claims_map.name"),
|
||||
Username: viper.GetString("oidc.claims_map.username"),
|
||||
Email: viper.GetString("oidc.claims_map.email"),
|
||||
Groups: viper.GetString("oidc.claims_map.groups"),
|
||||
}
|
||||
// get misc config
|
||||
miscConfig := OIDCMiscConfig{
|
||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||
FlattenGroups: viper.GetBool("oidc.flatten_groups"),
|
||||
FlattenSplter: viper.GetString("oidc.flatten_splitter"),
|
||||
}
|
||||
// get client secret
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
|
||||
if oidcClientSecretPath != "" && oidcClientSecret != "" {
|
||||
return OIDCConfig{}, errOidcMutuallyExclusive
|
||||
}
|
||||
if oidcClientSecretPath != "" {
|
||||
secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath))
|
||||
if err != nil {
|
||||
return OIDCConfig{}, err
|
||||
}
|
||||
oidcClientSecret = strings.TrimSpace(string(secretBytes))
|
||||
}
|
||||
OIDC := OIDCConfig{
|
||||
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
||||
"oidc.only_start_if_oidc_is_available",
|
||||
),
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: oidcClientSecret,
|
||||
Scope: viper.GetStringSlice("oidc.scope"),
|
||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||
Allowed: allowedConfig,
|
||||
ClaimsMap: claimsMap,
|
||||
Expiry: expireConfig,
|
||||
Misc: miscConfig,
|
||||
}
|
||||
return OIDC, nil
|
||||
}
|
||||
|
||||
func GetHeadscaleConfig() (*Config, error) {
|
||||
if IsCLIConfigured() {
|
||||
return &Config{
|
||||
|
@ -670,18 +771,10 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
logConfig := GetLogTailConfig()
|
||||
randomizeClientPort := viper.GetBool("randomize_client_port")
|
||||
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
|
||||
if oidcClientSecretPath != "" && oidcClientSecret != "" {
|
||||
return nil, errOidcMutuallyExclusive
|
||||
}
|
||||
if oidcClientSecretPath != "" {
|
||||
secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath))
|
||||
oidcConfig, err := GetOIDCConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oidcClientSecret = strings.TrimSpace(string(secretBytes))
|
||||
}
|
||||
|
||||
return &Config{
|
||||
ServerURL: viper.GetString("server_url"),
|
||||
|
@ -717,38 +810,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
|
||||
UnixSocket: viper.GetString("unix_socket"),
|
||||
UnixSocketPermission: util.GetFileMode("unix_socket_permission"),
|
||||
|
||||
OIDC: OIDCConfig{
|
||||
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
||||
"oidc.only_start_if_oidc_is_available",
|
||||
),
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: oidcClientSecret,
|
||||
Scope: viper.GetStringSlice("oidc.scope"),
|
||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||
Expiry: func() time.Duration {
|
||||
// if set to 0, we assume no expiry
|
||||
if value := viper.GetString("oidc.expiry"); value == "0" {
|
||||
return maxDuration
|
||||
} else {
|
||||
expiry, err := model.ParseDuration(value)
|
||||
if err != nil {
|
||||
log.Warn().Msg("failed to parse oidc.expiry, defaulting back to 180 days")
|
||||
|
||||
return defaultOIDCExpiryTime
|
||||
}
|
||||
|
||||
return time.Duration(expiry)
|
||||
}
|
||||
}(),
|
||||
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
||||
},
|
||||
|
||||
OIDC: oidcConfig,
|
||||
LogTail: logConfig,
|
||||
RandomizeClientPort: randomizeClientPort,
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
|||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.Misc.StripEmaildomain),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
|
@ -121,7 +121,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.Misc.StripEmaildomain),
|
||||
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
|
||||
}
|
||||
|
||||
|
@ -269,6 +269,9 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
|||
|
||||
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
|
||||
|
||||
oidcMisc := types.OIDCMiscConfig{
|
||||
StripEmaildomain: true,
|
||||
}
|
||||
return &types.OIDCConfig{
|
||||
Issuer: fmt.Sprintf(
|
||||
"http://%s/oidc",
|
||||
|
@ -276,7 +279,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
|||
),
|
||||
ClientID: "superclient",
|
||||
ClientSecret: "supersecret",
|
||||
StripEmaildomain: true,
|
||||
Misc: oidcMisc,
|
||||
OnlyStartIfOIDCIsAvailable: true,
|
||||
}, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue