diff --git a/CHANGELOG.md b/CHANGELOG.md index 916bf2c..f52c91f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594) ## 0.22.3 (2023-05-12) diff --git a/docs/oidc.md b/docs/oidc.md index 689e50c..cef6633 100644 --- a/docs/oidc.md +++ b/docs/oidc.md @@ -24,6 +24,13 @@ oidc: # It resolves environment variables, making integration to systemd's # `LoadCredential` straightforward: #client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" + # If provided, the name of a custom OIDC claim for specifying user groups. + # The claim value is expected to be a string or array of strings. + groups_claim: groups + # The OIDC claim to use as the email. + email_claim: email + # The OIDC claim to use as the username. + email_claim: preferred_username # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index fa9e75f..30ef1c8 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -6,6 +6,7 @@ import ( "crypto/rand" _ "embed" "encoding/hex" + "encoding/json" "errors" "fmt" "html/template" @@ -40,14 +41,47 @@ var ( errOIDCInvalidNodeState = errors.New( "requested node state key expired before authorisation completed", ) - errOIDCNodeKeyMissing = errors.New("could not get node key from cache") + errOIDCNodeKeyMissing = errors.New("could not get node key from cache") + errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token") + errOIDCUsernameClaimMissing = errors.New("username claim missing from ID Token") ) type IDTokenClaims struct { - Name string `json:"name,omitempty"` - Groups []string `json:"groups,omitempty"` - Email string `json:"email"` - Username string `json:"preferred_username,omitempty"` + // in some cases the groups might be a single value and not a list + Groups stringOrArray + Email string + Username string +} + +type stringOrArray []string + +func (s *stringOrArray) UnmarshalJSON(b []byte) error { + var a []string + if err := json.Unmarshal(b, &a); err == nil { + *s = a + return nil + } + var str string + if err := json.Unmarshal(b, &str); err != nil { + return err + } + *s = []string{str} + return nil +} + +type rawClaims map[string]json.RawMessage + +func (c rawClaims) unmarshalClaim(name string, v interface{}) error { + val, ok := c[name] + if !ok { + return fmt.Errorf("claim not present") + } + return json.Unmarshal([]byte(val), v) +} + +func (c rawClaims) hasClaim(name string) bool { + _, ok := c[name] + return ok } func (h *Headscale) initOIDC() error { @@ -215,7 +249,7 @@ func (h *Headscale) OIDCCallback( // return // } - claims, err := extractIDTokenClaims(writer, idToken) + claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken) if err != nil { return } @@ -360,25 +394,63 @@ func (h *Headscale) verifyIDTokenForOIDCCallback( func extractIDTokenClaims( writer http.ResponseWriter, + cfg types.OIDCConfig, idToken *oidc.IDToken, ) (*IDTokenClaims, error) { var claims IDTokenClaims - if err := idToken.Claims(&claims); err != nil { - util.LogErr(err, "Failed to decode id token claims") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("Failed to decode id token claims")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } + var rawClaims rawClaims + if err := idToken.Claims(&rawClaims); err != nil { + handleClaimError(writer, err) return nil, err } + if !rawClaims.hasClaim(cfg.EmailClaim) { + handleClaimError(writer, errOIDCEmailClaimMissing) + + return nil, errOIDCEmailClaimMissing + } + + if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil { + handleClaimError(writer, err) + + return nil, err + } + + if !rawClaims.hasClaim(cfg.UsernameClaim) { + handleClaimError(writer, errOIDCUsernameClaimMissing) + + return nil, errOIDCUsernameClaimMissing + } + + if err := rawClaims.unmarshalClaim(cfg.UsernameClaim, &claims.Username); err != nil { + handleClaimError(writer, err) + + return nil, err + } + + if rawClaims.hasClaim(cfg.GroupsClaim) { + if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil { + handleClaimError(writer, err) + + return nil, err + } + } + return &claims, nil } +func handleClaimError(writer http.ResponseWriter, err error) { + util.LogErr(err, "Failed to decode id token rawClaims") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, werr := writer.Write([]byte("Failed to decode id token rawClaims")) + if werr != nil { + util.LogErr(err, "Failed to write response") + } +} + // validateOIDCAllowedDomains checks that if AllowedDomains is provided, // that the authenticated principal ends with @. func validateOIDCAllowedDomains( diff --git a/hscontrol/oidc_test.go b/hscontrol/oidc_test.go new file mode 100644 index 0000000..e7e4c68 --- /dev/null +++ b/hscontrol/oidc_test.go @@ -0,0 +1,164 @@ +package hscontrol + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v3" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "reflect" + "testing" +) + +func Test_extractIDTokenClaims(t *testing.T) { + tests := []verificationTest{ + { + name: "default claim names", + idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "email", + GroupsClaim: "groups", + }, + want: &IDTokenClaims{ + Groups: []string{"group1", "group2"}, + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "custom claim names", + idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "my_custom_claim", + GroupsClaim: "https://foo.baz/groups", + }, + want: &IDTokenClaims{ + Groups: []string{"group3", "group4"}, + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "group claim not present", + idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`, + cfg: types.OIDCConfig{ + EmailClaim: "my_custom_claim", + GroupsClaim: "https://foo.baz/groups", + }, + want: &IDTokenClaims{ + Email: "foo@bar.baz", + }, + wantErr: false, + }, + { + name: "email claim not present", + idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`, + cfg: types.OIDCConfig{ + EmailClaim: "email", + GroupsClaim: "groups", + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + token, err := tt.getToken(t) + if err != nil { + t.Errorf("could not parse the token: %v", err) + + return + } + + if !tt.wantErr { + assert.Equal(t, 200, recorder.Result().StatusCode) + assert.Empty(t, recorder.Result().Header) + } + + got, err := extractIDTokenClaims(recorder, tt.cfg, token) + if (err != nil) != tt.wantErr { + t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want) + + return + } + }) + } +} + +type signingKey struct { + keyID string + key interface{} + pub interface{} + alg jose.SignatureAlgorithm +} + +// sign creates a JWS using the private key from the provided payload. +func (s *signingKey) sign(t testing.TB, payload []byte) string { + privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID} + + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil) + if err != nil { + t.Fatal(err) + } + jws, err := signer.Sign(payload) + if err != nil { + t.Fatal(err) + } + + data, err := jws.CompactSerialize() + if err != nil { + t.Fatal(err) + } + + return data +} + +type verificationTest struct { + name string + idToken string + cfg types.OIDCConfig + want *IDTokenClaims + wantErr bool +} + +func newRSAKey(t testing.TB) *signingKey { + priv, err := rsa.GenerateKey(rand.Reader, 1028) + if err != nil { + t.Fatal(err) + } + + return &signingKey{"", priv, priv.Public(), jose.RS256} +} + +func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) { + key := newRSAKey(t) + token := key.sign(t, []byte(v.idToken)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + verifier := oidc.NewVerifier( + "https://foo", + &oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}}, + &oidc.Config{ + SkipClientIDCheck: true, + SkipExpiryCheck: true, + SkipIssuerCheck: true, + InsecureSkipSignatureCheck: true, + }, + ) + + return verifier.Verify(ctx, token) +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 983cf34..2d7068a 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -102,6 +102,9 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string + GroupsClaim string + EmailClaim string + UsernameClaim string StripEmaildomain bool UseUsernameClaim bool Expiry time.Duration @@ -187,6 +190,9 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("oidc.use_username_claim", false) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") + viper.SetDefault("oidc.groups_claim", "groups") + viper.SetDefault("oidc.email_claim", "email") + viper.SetDefault("oidc.username_claim", "preferred_username") viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("logtail.enabled", false) @@ -634,6 +640,9 @@ func GetHeadscaleConfig() (*Config, error) { AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), UseUsernameClaim: viper.GetBool("oidc.use_username_claim"), + GroupsClaim: viper.GetString("oidc.groups_claim"), + EmailClaim: viper.GetString("oidc.email_claim"), + UsernameClaim: viper.GetString("oidc.username_claim"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), Expiry: func() time.Duration { // if set to 0, we assume no expiry