165 lines
3.7 KiB
Go
165 lines
3.7 KiB
Go
|
package hscontrol
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto"
|
||
|
"crypto/rand"
|
||
|
"crypto/rsa"
|
||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||
|
"github.com/go-jose/go-jose/v3"
|
||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"net/http/httptest"
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func Test_extractIDTokenClaims(t *testing.T) {
|
||
|
tests := []verificationTest{
|
||
|
{
|
||
|
name: "default claim names",
|
||
|
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
|
||
|
cfg: types.OIDCConfig{
|
||
|
EmailClaim: "email",
|
||
|
GroupsClaim: "groups",
|
||
|
},
|
||
|
want: &IDTokenClaims{
|
||
|
Groups: []string{"group1", "group2"},
|
||
|
Email: "foo@bar.baz",
|
||
|
},
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
name: "custom claim names",
|
||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
|
||
|
cfg: types.OIDCConfig{
|
||
|
EmailClaim: "my_custom_claim",
|
||
|
GroupsClaim: "https://foo.baz/groups",
|
||
|
},
|
||
|
want: &IDTokenClaims{
|
||
|
Groups: []string{"group3", "group4"},
|
||
|
Email: "foo@bar.baz",
|
||
|
},
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
name: "group claim not present",
|
||
|
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
|
||
|
cfg: types.OIDCConfig{
|
||
|
EmailClaim: "my_custom_claim",
|
||
|
GroupsClaim: "https://foo.baz/groups",
|
||
|
},
|
||
|
want: &IDTokenClaims{
|
||
|
Email: "foo@bar.baz",
|
||
|
},
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
name: "email claim not present",
|
||
|
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
|
||
|
cfg: types.OIDCConfig{
|
||
|
EmailClaim: "email",
|
||
|
GroupsClaim: "groups",
|
||
|
},
|
||
|
want: nil,
|
||
|
wantErr: true,
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
tt := tt
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
recorder := httptest.NewRecorder()
|
||
|
token, err := tt.getToken(t)
|
||
|
if err != nil {
|
||
|
t.Errorf("could not parse the token: %v", err)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if !tt.wantErr {
|
||
|
assert.Equal(t, 200, recorder.Result().StatusCode)
|
||
|
assert.Empty(t, recorder.Result().Header)
|
||
|
}
|
||
|
|
||
|
got, err := extractIDTokenClaims(recorder, tt.cfg, token)
|
||
|
if (err != nil) != tt.wantErr {
|
||
|
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type signingKey struct {
|
||
|
keyID string
|
||
|
key interface{}
|
||
|
pub interface{}
|
||
|
alg jose.SignatureAlgorithm
|
||
|
}
|
||
|
|
||
|
// sign creates a JWS using the private key from the provided payload.
|
||
|
func (s *signingKey) sign(t testing.TB, payload []byte) string {
|
||
|
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}
|
||
|
|
||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
jws, err := signer.Sign(payload)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
data, err := jws.CompactSerialize()
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
return data
|
||
|
}
|
||
|
|
||
|
type verificationTest struct {
|
||
|
name string
|
||
|
idToken string
|
||
|
cfg types.OIDCConfig
|
||
|
want *IDTokenClaims
|
||
|
wantErr bool
|
||
|
}
|
||
|
|
||
|
func newRSAKey(t testing.TB) *signingKey {
|
||
|
priv, err := rsa.GenerateKey(rand.Reader, 1028)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
return &signingKey{"", priv, priv.Public(), jose.RS256}
|
||
|
}
|
||
|
|
||
|
func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
|
||
|
key := newRSAKey(t)
|
||
|
token := key.sign(t, []byte(v.idToken))
|
||
|
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
verifier := oidc.NewVerifier(
|
||
|
"https://foo",
|
||
|
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
|
||
|
&oidc.Config{
|
||
|
SkipClientIDCheck: true,
|
||
|
SkipExpiryCheck: true,
|
||
|
SkipIssuerCheck: true,
|
||
|
InsecureSkipSignatureCheck: true,
|
||
|
},
|
||
|
)
|
||
|
|
||
|
return verifier.Verify(ctx, token)
|
||
|
}
|