Merge remote-tracking branch 'fen4o/add-oidc-claim-names'
This commit is contained in:
commit
1592c8edc9
5 changed files with 266 additions and 15 deletions
|
@ -36,6 +36,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co
|
||||||
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
||||||
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
||||||
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
||||||
|
Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
|
||||||
|
|
||||||
## 0.22.3 (2023-05-12)
|
## 0.22.3 (2023-05-12)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,11 @@ oidc:
|
||||||
# It resolves environment variables, making integration to systemd's
|
# It resolves environment variables, making integration to systemd's
|
||||||
# `LoadCredential` straightforward:
|
# `LoadCredential` straightforward:
|
||||||
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
|
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
|
||||||
|
# If provided, the name of a custom OIDC claim for specifying user groups.
|
||||||
|
# The claim value is expected to be a string or array of strings.
|
||||||
|
groups_claim: groups
|
||||||
|
# The OIDC claim to use as the email.
|
||||||
|
email_claim: email
|
||||||
|
|
||||||
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
||||||
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
@ -40,14 +41,47 @@ var (
|
||||||
errOIDCInvalidNodeState = errors.New(
|
errOIDCInvalidNodeState = errors.New(
|
||||||
"requested node state key expired before authorisation completed",
|
"requested node state key expired before authorisation completed",
|
||||||
)
|
)
|
||||||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||||
|
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
|
||||||
|
errOIDCUsernameClaimMissing = errors.New("username claim missing from ID Token")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
Name string `json:"name,omitempty"`
|
// in some cases the groups might be a single value and not a list
|
||||||
Groups []string `json:"groups,omitempty"`
|
Groups stringOrArray
|
||||||
Email string `json:"email"`
|
Email string
|
||||||
Username string `json:"preferred_username,omitempty"`
|
Username string
|
||||||
|
}
|
||||||
|
|
||||||
|
type stringOrArray []string
|
||||||
|
|
||||||
|
func (s *stringOrArray) UnmarshalJSON(b []byte) error {
|
||||||
|
var a []string
|
||||||
|
if err := json.Unmarshal(b, &a); err == nil {
|
||||||
|
*s = a
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var str string
|
||||||
|
if err := json.Unmarshal(b, &str); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*s = []string{str}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type rawClaims map[string]json.RawMessage
|
||||||
|
|
||||||
|
func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
|
||||||
|
val, ok := c[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("claim not present")
|
||||||
|
}
|
||||||
|
return json.Unmarshal([]byte(val), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c rawClaims) hasClaim(name string) bool {
|
||||||
|
_, ok := c[name]
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) initOIDC() error {
|
func (h *Headscale) initOIDC() error {
|
||||||
|
@ -215,7 +249,7 @@ func (h *Headscale) OIDCCallback(
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
|
|
||||||
claims, err := extractIDTokenClaims(writer, idToken)
|
claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -360,25 +394,63 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
|
||||||
|
|
||||||
func extractIDTokenClaims(
|
func extractIDTokenClaims(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
|
cfg types.OIDCConfig,
|
||||||
idToken *oidc.IDToken,
|
idToken *oidc.IDToken,
|
||||||
) (*IDTokenClaims, error) {
|
) (*IDTokenClaims, error) {
|
||||||
var claims IDTokenClaims
|
var claims IDTokenClaims
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
var rawClaims rawClaims
|
||||||
util.LogErr(err, "Failed to decode id token claims")
|
if err := idToken.Claims(&rawClaims); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("Failed to decode id token claims"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !rawClaims.hasClaim(cfg.EmailClaim) {
|
||||||
|
handleClaimError(writer, errOIDCEmailClaimMissing)
|
||||||
|
|
||||||
|
return nil, errOIDCEmailClaimMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rawClaims.hasClaim(cfg.UsernameClaim) {
|
||||||
|
handleClaimError(writer, errOIDCUsernameClaimMissing)
|
||||||
|
|
||||||
|
return nil, errOIDCUsernameClaimMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rawClaims.unmarshalClaim(cfg.UsernameClaim, &claims.Username); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawClaims.hasClaim(cfg.GroupsClaim) {
|
||||||
|
if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil {
|
||||||
|
handleClaimError(writer, err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &claims, nil
|
return &claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleClaimError(writer http.ResponseWriter, err error) {
|
||||||
|
util.LogErr(err, "Failed to decode id token rawClaims")
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, werr := writer.Write([]byte("Failed to decode id token rawClaims"))
|
||||||
|
if werr != nil {
|
||||||
|
util.LogErr(err, "Failed to write response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
||||||
// that the authenticated principal ends with @<alloweddomain>.
|
// that the authenticated principal ends with @<alloweddomain>.
|
||||||
func validateOIDCAllowedDomains(
|
func validateOIDCAllowedDomains(
|
||||||
|
|
164
hscontrol/oidc_test.go
Normal file
164
hscontrol/oidc_test.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package hscontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_extractIDTokenClaims(t *testing.T) {
|
||||||
|
tests := []verificationTest{
|
||||||
|
{
|
||||||
|
name: "default claim names",
|
||||||
|
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Groups: []string{"group1", "group2"},
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom claim names",
|
||||||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "my_custom_claim",
|
||||||
|
GroupsClaim: "https://foo.baz/groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Groups: []string{"group3", "group4"},
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "group claim not present",
|
||||||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "my_custom_claim",
|
||||||
|
GroupsClaim: "https://foo.baz/groups",
|
||||||
|
},
|
||||||
|
want: &IDTokenClaims{
|
||||||
|
Email: "foo@bar.baz",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "email claim not present",
|
||||||
|
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
|
||||||
|
cfg: types.OIDCConfig{
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
token, err := tt.getToken(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("could not parse the token: %v", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
assert.Equal(t, 200, recorder.Result().StatusCode)
|
||||||
|
assert.Empty(t, recorder.Result().Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := extractIDTokenClaims(recorder, tt.cfg, token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type signingKey struct {
|
||||||
|
keyID string
|
||||||
|
key interface{}
|
||||||
|
pub interface{}
|
||||||
|
alg jose.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
// sign creates a JWS using the private key from the provided payload.
|
||||||
|
func (s *signingKey) sign(t testing.TB, payload []byte) string {
|
||||||
|
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}
|
||||||
|
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
jws, err := signer.Sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jws.CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
type verificationTest struct {
|
||||||
|
name string
|
||||||
|
idToken string
|
||||||
|
cfg types.OIDCConfig
|
||||||
|
want *IDTokenClaims
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRSAKey(t testing.TB) *signingKey {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &signingKey{"", priv, priv.Public(), jose.RS256}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
|
||||||
|
key := newRSAKey(t)
|
||||||
|
token := key.sign(t, []byte(v.idToken))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
verifier := oidc.NewVerifier(
|
||||||
|
"https://foo",
|
||||||
|
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
|
||||||
|
&oidc.Config{
|
||||||
|
SkipClientIDCheck: true,
|
||||||
|
SkipExpiryCheck: true,
|
||||||
|
SkipIssuerCheck: true,
|
||||||
|
InsecureSkipSignatureCheck: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return verifier.Verify(ctx, token)
|
||||||
|
}
|
|
@ -102,6 +102,9 @@ type OIDCConfig struct {
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
|
GroupsClaim string
|
||||||
|
EmailClaim string
|
||||||
|
UsernameClaim string
|
||||||
StripEmaildomain bool
|
StripEmaildomain bool
|
||||||
UseUsernameClaim bool
|
UseUsernameClaim bool
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
|
@ -187,6 +190,9 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("oidc.use_username_claim", false)
|
viper.SetDefault("oidc.use_username_claim", false)
|
||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
|
viper.SetDefault("oidc.groups_claim", "groups")
|
||||||
|
viper.SetDefault("oidc.email_claim", "email")
|
||||||
|
viper.SetDefault("oidc.username_claim", "preferred_username")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
|
||||||
viper.SetDefault("logtail.enabled", false)
|
viper.SetDefault("logtail.enabled", false)
|
||||||
|
@ -634,6 +640,9 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||||
UseUsernameClaim: viper.GetBool("oidc.use_username_claim"),
|
UseUsernameClaim: viper.GetBool("oidc.use_username_claim"),
|
||||||
|
GroupsClaim: viper.GetString("oidc.groups_claim"),
|
||||||
|
EmailClaim: viper.GetString("oidc.email_claim"),
|
||||||
|
UsernameClaim: viper.GetString("oidc.username_claim"),
|
||||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||||
Expiry: func() time.Duration {
|
Expiry: func() time.Duration {
|
||||||
// if set to 0, we assume no expiry
|
// if set to 0, we assume no expiry
|
||||||
|
|
Loading…
Reference in a new issue