From 9d5848990348c4107859ed3f5844390d7211a4cd Mon Sep 17 00:00:00 2001 From: fen4o <1710500+fen4o@users.noreply.github.com> Date: Wed, 8 Nov 2023 13:32:47 +0200 Subject: [PATCH] Add OIDC claim names options Some identity providers (auth0 for example) do not allow to set the groups claims and administrators must use custom claims names and add them in the id token. This commit adds the following configuration options: - `oidc.groups_claim` to set the groups claim name - `oidc.email_claim` to set the email claim name All claims default to the previous values for backwards compatibility. The groups claim can now also accept `[]string` or `string` as some providers might return only a string response instead of array. --- CHANGELOG.md | 1 + docs/oidc.md | 5 ++ hscontrol/oidc.go | 88 ++++++++++++++++---- hscontrol/oidc_test.go | 164 ++++++++++++++++++++++++++++++++++++++ hscontrol/types/config.go | 6 ++ 5 files changed, 249 insertions(+), 15 deletions(-) create mode 100644 hscontrol/oidc_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0880a12..d0e3de1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594) ## 0.22.3 (2023-05-12) diff --git a/docs/oidc.md b/docs/oidc.md index 189d7cd..ac4f692 100644 --- a/docs/oidc.md +++ b/docs/oidc.md @@ -24,6 +24,11 @@ oidc: # It resolves environment variables, making integration to systemd's # `LoadCredential` straightforward: #client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" + # If provided, the name of a custom OIDC claim for specifying user groups. + # The claim value is expected to be a string or array of strings. + groups_claim: groups + # The OIDC claim to use as the email. + email_claim: email # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index b32d751..b594915 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -6,6 +6,7 @@ import ( "crypto/rand" _ "embed" "encoding/hex" + "encoding/json" "errors" "fmt" "html/template" @@ -40,14 +41,45 @@ var ( errOIDCInvalidNodeState = errors.New( "requested node state key expired before authorisation completed", ) - errOIDCNodeKeyMissing = errors.New("could not get node key from cache") + errOIDCNodeKeyMissing = errors.New("could not get node key from cache") + errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token") ) type IDTokenClaims struct { - Name string `json:"name,omitempty"` - Groups []string `json:"groups,omitempty"` - Email string `json:"email"` - Username string `json:"preferred_username,omitempty"` + // in some cases the groups might be a single value and not a list + Groups stringOrArray + Email string +} + +type stringOrArray []string + +func (s *stringOrArray) UnmarshalJSON(b []byte) error { + var a []string + if err := json.Unmarshal(b, &a); err == nil { + *s = a + return nil + } + var str string + if err := json.Unmarshal(b, &str); err != nil { + return err + } + *s = []string{str} + return nil +} + +type rawClaims map[string]json.RawMessage + +func (c rawClaims) unmarshalClaim(name string, v interface{}) error { + val, ok := c[name] + if !ok { + return fmt.Errorf("claim not present") + } + return json.Unmarshal([]byte(val), v) +} + +func (c rawClaims) hasClaim(name string) bool { + _, ok := c[name] + return ok } func (h *Headscale) initOIDC() error { @@ -215,7 +247,7 @@ func (h *Headscale) OIDCCallback( // return // } - claims, err := extractIDTokenClaims(writer, idToken) + claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken) if err != nil { return } @@ -355,25 +387,51 @@ func (h *Headscale) verifyIDTokenForOIDCCallback( func extractIDTokenClaims( writer http.ResponseWriter, + cfg types.OIDCConfig, idToken *oidc.IDToken, ) (*IDTokenClaims, error) { var claims IDTokenClaims - if err := idToken.Claims(&claims); err != nil { - util.LogErr(err, "Failed to decode id token claims") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("Failed to decode id token claims")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } + var rawClaims rawClaims + if err := idToken.Claims(&rawClaims); err != nil { + handleClaimError(writer, err) return nil, err } + if !rawClaims.hasClaim(cfg.EmailClaim) { + handleClaimError(writer, errOIDCEmailClaimMissing) + + return nil, errOIDCEmailClaimMissing + } + + if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil { + handleClaimError(writer, err) + + return nil, err + } + + if rawClaims.hasClaim(cfg.GroupsClaim) { + if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil { + handleClaimError(writer, err) + + return nil, err + } + } + return &claims, nil } +func handleClaimError(writer http.ResponseWriter, err error) { + util.LogErr(err, "Failed to decode id token rawClaims") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, werr := writer.Write([]byte("Failed to decode id token rawClaims")) + if werr != nil { + util.LogErr(err, "Failed to write response") + } +} + // validateOIDCAllowedDomains checks that if AllowedDomains is provided, // that the authenticated principal ends with @. func validateOIDCAllowedDomains( diff --git a/hscontrol/oidc_test.go b/hscontrol/oidc_test.go new file mode 100644 index 0000000..e7e4c68 --- /dev/null +++ b/hscontrol/oidc_test.go @@ -0,0 +1,164 @@ +package hscontrol + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v3" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "reflect" + "testing" +) + +func Test_extractIDTokenClaims(t *testing.T) { + tests := []verificationTest{ + { + name: "default claim names", + idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "email", + GroupsClaim: "groups", + }, + want: &IDTokenClaims{ + Groups: []string{"group1", "group2"}, + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "custom claim names", + idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "my_custom_claim", + GroupsClaim: "https://foo.baz/groups", + }, + want: &IDTokenClaims{ + Groups: []string{"group3", "group4"}, + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "group claim not present", + idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`, + cfg: types.OIDCConfig{ + EmailClaim: "my_custom_claim", + GroupsClaim: "https://foo.baz/groups", + }, + want: &IDTokenClaims{ + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "email claim not present", + idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "email", + GroupsClaim: "groups", + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + token, err := tt.getToken(t) + if err != nil { + t.Errorf("could not parse the token: %v", err) + + return + } + + if !tt.wantErr { + assert.Equal(t, 200, recorder.Result().StatusCode) + assert.Empty(t, recorder.Result().Header) + } + + got, err := extractIDTokenClaims(recorder, tt.cfg, token) + if (err != nil) != tt.wantErr { + t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want) + + return + } + }) + } +} + +type signingKey struct { + keyID string + key interface{} + pub interface{} + alg jose.SignatureAlgorithm +} + +// sign creates a JWS using the private key from the provided payload. +func (s *signingKey) sign(t testing.TB, payload []byte) string { + privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID} + + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil) + if err != nil { + t.Fatal(err) + } + jws, err := signer.Sign(payload) + if err != nil { + t.Fatal(err) + } + + data, err := jws.CompactSerialize() + if err != nil { + t.Fatal(err) + } + + return data +} + +type verificationTest struct { + name string + idToken string + cfg types.OIDCConfig + want *IDTokenClaims + wantErr bool +} + +func newRSAKey(t testing.TB) *signingKey { + priv, err := rsa.GenerateKey(rand.Reader, 1028) + if err != nil { + t.Fatal(err) + } + + return &signingKey{"", priv, priv.Public(), jose.RS256} +} + +func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) { + key := newRSAKey(t) + token := key.sign(t, []byte(v.idToken)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + verifier := oidc.NewVerifier( + "https://foo", + &oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}}, + &oidc.Config{ + SkipClientIDCheck: true, + SkipExpiryCheck: true, + SkipIssuerCheck: true, + InsecureSkipSignatureCheck: true, + }, + ) + + return verifier.Verify(ctx, token) +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index e78795d..fa76ca3 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -102,6 +102,8 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string + GroupsClaim string + EmailClaim string StripEmaildomain bool Expiry time.Duration UseExpiryFromToken bool @@ -185,6 +187,8 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") + viper.SetDefault("oidc.groups_claim", "groups") + viper.SetDefault("oidc.email_claim", "email") viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("logtail.enabled", false) @@ -631,6 +635,8 @@ func GetHeadscaleConfig() (*Config, error) { AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), + GroupsClaim: viper.GetString("oidc.groups_claim"), + EmailClaim: viper.GetString("oidc.email_claim"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), Expiry: func() time.Duration { // if set to 0, we assume no expiry