Add OIDC claim names options

Some identity providers (auth0 for example) do not allow to set the
groups claims and administrators must use custom claims names and add
them in the id token.

This commit adds the following configuration options:

- `oidc.groups_claim` to set the groups claim name
- `oidc.email_claim` to set the email claim name

All claims default to the previous values for backwards compatibility.

The groups claim can now also accept `[]string` or `string` as some
providers might return only a string response instead of array.
This commit is contained in:
fen4o 2023-11-08 13:32:47 +02:00
parent 2af71c9e31
commit 9d58489903
5 changed files with 249 additions and 15 deletions

View file

@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

View file

@ -24,6 +24,11 @@ oidc:
# It resolves environment variables, making integration to systemd's # It resolves environment variables, making integration to systemd's
# `LoadCredential` straightforward: # `LoadCredential` straightforward:
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" #client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
# If provided, the name of a custom OIDC claim for specifying user groups.
# The claim value is expected to be a string or array of strings.
groups_claim: groups
# The OIDC claim to use as the email.
email_claim: email
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".

View file

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
_ "embed" _ "embed"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -40,14 +41,45 @@ var (
errOIDCInvalidNodeState = errors.New( errOIDCInvalidNodeState = errors.New(
"requested node state key expired before authorisation completed", "requested node state key expired before authorisation completed",
) )
errOIDCNodeKeyMissing = errors.New("could not get node key from cache") errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
) )
type IDTokenClaims struct { type IDTokenClaims struct {
Name string `json:"name,omitempty"` // in some cases the groups might be a single value and not a list
Groups []string `json:"groups,omitempty"` Groups stringOrArray
Email string `json:"email"` Email string
Username string `json:"preferred_username,omitempty"` }
type stringOrArray []string
func (s *stringOrArray) UnmarshalJSON(b []byte) error {
var a []string
if err := json.Unmarshal(b, &a); err == nil {
*s = a
return nil
}
var str string
if err := json.Unmarshal(b, &str); err != nil {
return err
}
*s = []string{str}
return nil
}
type rawClaims map[string]json.RawMessage
func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
val, ok := c[name]
if !ok {
return fmt.Errorf("claim not present")
}
return json.Unmarshal([]byte(val), v)
}
func (c rawClaims) hasClaim(name string) bool {
_, ok := c[name]
return ok
} }
func (h *Headscale) initOIDC() error { func (h *Headscale) initOIDC() error {
@ -215,7 +247,7 @@ func (h *Headscale) OIDCCallback(
// return // return
// } // }
claims, err := extractIDTokenClaims(writer, idToken) claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken)
if err != nil { if err != nil {
return return
} }
@ -355,25 +387,51 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
func extractIDTokenClaims( func extractIDTokenClaims(
writer http.ResponseWriter, writer http.ResponseWriter,
cfg types.OIDCConfig,
idToken *oidc.IDToken, idToken *oidc.IDToken,
) (*IDTokenClaims, error) { ) (*IDTokenClaims, error) {
var claims IDTokenClaims var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil { var rawClaims rawClaims
util.LogErr(err, "Failed to decode id token claims") if err := idToken.Claims(&rawClaims); err != nil {
handleClaimError(writer, err)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err return nil, err
} }
if !rawClaims.hasClaim(cfg.EmailClaim) {
handleClaimError(writer, errOIDCEmailClaimMissing)
return nil, errOIDCEmailClaimMissing
}
if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil {
handleClaimError(writer, err)
return nil, err
}
if rawClaims.hasClaim(cfg.GroupsClaim) {
if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil {
handleClaimError(writer, err)
return nil, err
}
}
return &claims, nil return &claims, nil
} }
func handleClaimError(writer http.ResponseWriter, err error) {
util.LogErr(err, "Failed to decode id token rawClaims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token rawClaims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided, // validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>. // that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains( func validateOIDCAllowedDomains(

164
hscontrol/oidc_test.go Normal file
View file

@ -0,0 +1,164 @@
package hscontrol
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"reflect"
"testing"
)
func Test_extractIDTokenClaims(t *testing.T) {
tests := []verificationTest{
{
name: "default claim names",
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: &IDTokenClaims{
Groups: []string{"group1", "group2"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "custom claim names",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Groups: []string{"group3", "group4"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "group claim not present",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "email claim not present",
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
token, err := tt.getToken(t)
if err != nil {
t.Errorf("could not parse the token: %v", err)
return
}
if !tt.wantErr {
assert.Equal(t, 200, recorder.Result().StatusCode)
assert.Empty(t, recorder.Result().Header)
}
got, err := extractIDTokenClaims(recorder, tt.cfg, token)
if (err != nil) != tt.wantErr {
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)
return
}
})
}
}
type signingKey struct {
keyID string
key interface{}
pub interface{}
alg jose.SignatureAlgorithm
}
// sign creates a JWS using the private key from the provided payload.
func (s *signingKey) sign(t testing.TB, payload []byte) string {
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}
data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return data
}
type verificationTest struct {
name string
idToken string
cfg types.OIDCConfig
want *IDTokenClaims
wantErr bool
}
func newRSAKey(t testing.TB) *signingKey {
priv, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
t.Fatal(err)
}
return &signingKey{"", priv, priv.Public(), jose.RS256}
}
func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
key := newRSAKey(t)
token := key.sign(t, []byte(v.idToken))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
verifier := oidc.NewVerifier(
"https://foo",
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
&oidc.Config{
SkipClientIDCheck: true,
SkipExpiryCheck: true,
SkipIssuerCheck: true,
InsecureSkipSignatureCheck: true,
},
)
return verifier.Verify(ctx, token)
}

View file

@ -102,6 +102,8 @@ type OIDCConfig struct {
AllowedDomains []string AllowedDomains []string
AllowedUsers []string AllowedUsers []string
AllowedGroups []string AllowedGroups []string
GroupsClaim string
EmailClaim string
StripEmaildomain bool StripEmaildomain bool
Expiry time.Duration Expiry time.Duration
UseExpiryFromToken bool UseExpiryFromToken bool
@ -185,6 +187,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.groups_claim", "groups")
viper.SetDefault("oidc.email_claim", "email")
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("logtail.enabled", false) viper.SetDefault("logtail.enabled", false)
@ -631,6 +635,8 @@ func GetHeadscaleConfig() (*Config, error) {
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
GroupsClaim: viper.GetString("oidc.groups_claim"),
EmailClaim: viper.GetString("oidc.email_claim"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Expiry: func() time.Duration { Expiry: func() time.Duration {
// if set to 0, we assume no expiry // if set to 0, we assume no expiry