Add OIDC claim names options
Some identity providers (auth0 for example) do not allow to set the groups claims and administrators must use custom claims names and add them in the id token. This commit adds the following configuration options: - `oidc.groups_claim` to set the groups claim name - `oidc.email_claim` to set the email claim name All claims default to the previous values for backwards compatibility. The groups claim can now also accept `[]string` or `string` as some providers might return only a string response instead of array.
This commit is contained in:
parent
2af71c9e31
commit
9d58489903
5 changed files with 249 additions and 15 deletions
|
@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co
|
||||||
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
||||||
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
||||||
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
||||||
|
Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
|
||||||
|
|
||||||
## 0.22.3 (2023-05-12)
|
## 0.22.3 (2023-05-12)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,11 @@ oidc:
|
||||||
# It resolves environment variables, making integration to systemd's
|
# It resolves environment variables, making integration to systemd's
|
||||||
# `LoadCredential` straightforward:
|
# `LoadCredential` straightforward:
|
||||||
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
|
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
|
||||||
|
# If provided, the name of a custom OIDC claim for specifying user groups.
|
||||||
|
# The claim value is expected to be a string or array of strings.
|
||||||
|
groups_claim: groups
|
||||||
|
# The OIDC claim to use as the email.
|
||||||
|
email_claim: email
|
||||||
|
|
||||||
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
||||||
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
@ -41,13 +42,44 @@ var (
|
||||||
"requested node state key expired before authorisation completed",
|
"requested node state key expired before authorisation completed",
|
||||||
)
|
)
|
||||||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||||
|
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
Name string `json:"name,omitempty"`
|
// in some cases the groups might be a single value and not a list
|
||||||
Groups []string `json:"groups,omitempty"`
|
Groups stringOrArray
|
||||||
Email string `json:"email"`
|
Email string
|
||||||
Username string `json:"preferred_username,omitempty"`
|
}
|
||||||
|
|
||||||
|
type stringOrArray []string
|
||||||
|
|
||||||
|
func (s *stringOrArray) UnmarshalJSON(b []byte) error {
|
||||||
|
var a []string
|
||||||
|
if err := json.Unmarshal(b, &a); err == nil {
|
||||||
|
*s = a
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var str string
|
||||||
|
if err := json.Unmarshal(b, &str); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*s = []string{str}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rawClaims map[string]json.RawMessage
|
||||||
|
|
||||||
|
func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
|
||||||
|
val, ok := c[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("claim not present")
|
||||||
|
}
|
||||||
|
return json.Unmarshal([]byte(val), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c rawClaims) hasClaim(name string) bool {
|
||||||
|
_, ok := c[name]
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) initOIDC() error {
|
func (h *Headscale) initOIDC() error {
|
||||||
|
@ -215,7 +247,7 @@ func (h *Headscale) OIDCCallback(
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
|
|
||||||
claims, err := extractIDTokenClaims(writer, idToken)
|
claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -355,25 +387,51 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
|
||||||
|
|
||||||
func extractIDTokenClaims(
|
func extractIDTokenClaims(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
|
cfg types.OIDCConfig,
|
||||||
idToken *oidc.IDToken,
|
idToken *oidc.IDToken,
|
||||||
) (*IDTokenClaims, error) {
|
) (*IDTokenClaims, error) {
|
||||||
var claims IDTokenClaims
|
var claims IDTokenClaims
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
var rawClaims rawClaims
|
||||||
util.LogErr(err, "Failed to decode id token claims")
|
if err := idToken.Claims(&rawClaims); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("Failed to decode id token claims"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !rawClaims.hasClaim(cfg.EmailClaim) {
|
||||||
|
handleClaimError(writer, errOIDCEmailClaimMissing)
|
||||||
|
|
||||||
|
return nil, errOIDCEmailClaimMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawClaims.hasClaim(cfg.GroupsClaim) {
|
||||||
|
if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &claims, nil
|
return &claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleClaimError(writer http.ResponseWriter, err error) {
|
||||||
|
util.LogErr(err, "Failed to decode id token rawClaims")
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, werr := writer.Write([]byte("Failed to decode id token rawClaims"))
|
||||||
|
if werr != nil {
|
||||||
|
util.LogErr(err, "Failed to write response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
||||||
// that the authenticated principal ends with @<alloweddomain>.
|
// that the authenticated principal ends with @<alloweddomain>.
|
||||||
func validateOIDCAllowedDomains(
|
func validateOIDCAllowedDomains(
|
||||||
|
|
164
hscontrol/oidc_test.go
Normal file
164
hscontrol/oidc_test.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package hscontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_extractIDTokenClaims(t *testing.T) {
|
||||||
|
tests := []verificationTest{
|
||||||
|
{
|
||||||
|
name: "default claim names",
|
||||||
|
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Groups: []string{"group1", "group2"},
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom claim names",
|
||||||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "my_custom_claim",
|
||||||
|
GroupsClaim: "https://foo.baz/groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Groups: []string{"group3", "group4"},
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "group claim not present",
|
||||||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "my_custom_claim",
|
||||||
|
GroupsClaim: "https://foo.baz/groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "email claim not present",
|
||||||
|
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
token, err := tt.getToken(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("could not parse the token: %v", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
assert.Equal(t, 200, recorder.Result().StatusCode)
|
||||||
|
assert.Empty(t, recorder.Result().Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := extractIDTokenClaims(recorder, tt.cfg, token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type signingKey struct {
|
||||||
|
keyID string
|
||||||
|
key interface{}
|
||||||
|
pub interface{}
|
||||||
|
alg jose.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
// sign creates a JWS using the private key from the provided payload.
|
||||||
|
func (s *signingKey) sign(t testing.TB, payload []byte) string {
|
||||||
|
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}
|
||||||
|
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
jws, err := signer.Sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jws.CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
type verificationTest struct {
|
||||||
|
name string
|
||||||
|
idToken string
|
||||||
|
cfg types.OIDCConfig
|
||||||
|
want *IDTokenClaims
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRSAKey(t testing.TB) *signingKey {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &signingKey{"", priv, priv.Public(), jose.RS256}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
|
||||||
|
key := newRSAKey(t)
|
||||||
|
token := key.sign(t, []byte(v.idToken))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
verifier := oidc.NewVerifier(
|
||||||
|
"https://foo",
|
||||||
|
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
|
||||||
|
&oidc.Config{
|
||||||
|
SkipClientIDCheck: true,
|
||||||
|
SkipExpiryCheck: true,
|
||||||
|
SkipIssuerCheck: true,
|
||||||
|
InsecureSkipSignatureCheck: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return verifier.Verify(ctx, token)
|
||||||
|
}
|
|
@ -102,6 +102,8 @@ type OIDCConfig struct {
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
|
GroupsClaim string
|
||||||
|
EmailClaim string
|
||||||
StripEmaildomain bool
|
StripEmaildomain bool
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
UseExpiryFromToken bool
|
UseExpiryFromToken bool
|
||||||
|
@ -185,6 +187,8 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("oidc.strip_email_domain", true)
|
viper.SetDefault("oidc.strip_email_domain", true)
|
||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
|
viper.SetDefault("oidc.groups_claim", "groups")
|
||||||
|
viper.SetDefault("oidc.email_claim", "email")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
|
||||||
viper.SetDefault("logtail.enabled", false)
|
viper.SetDefault("logtail.enabled", false)
|
||||||
|
@ -631,6 +635,8 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||||
|
GroupsClaim: viper.GetString("oidc.groups_claim"),
|
||||||
|
EmailClaim: viper.GetString("oidc.email_claim"),
|
||||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||||
Expiry: func() time.Duration {
|
Expiry: func() time.Duration {
|
||||||
// if set to 0, we assume no expiry
|
// if set to 0, we assume no expiry
|
||||||
|
|
Loading…
Reference in a new issue