Compare commits

...

4 commits

Author SHA1 Message Date
Ty
1592c8edc9
Merge remote-tracking branch 'fen4o/add-oidc-claim-names' 2023-11-12 21:53:23 -07:00
Ty
5fc56adf11
Merge remote-tracking branch 'meson800/oidc_username_claim' 2023-11-12 16:34:35 -07:00
fen4o
9d58489903 Add OIDC claim names options
Some identity providers (auth0 for example) do not allow to set the
groups claims and administrators must use custom claims names and add
them in the id token.

This commit adds the following configuration options:

- `oidc.groups_claim` to set the groups claim name
- `oidc.email_claim` to set the email claim name

All claims default to the previous values for backwards compatibility.

The groups claim can now also accept `[]string` or `string` as some
providers might return only a string response instead of array.
2023-11-08 16:00:07 +02:00
Christopher Johnstone
205a008013 Allow use of the preferred_username OIDC claim
Previously, Headscale would only use the `email` OIDC
claim to set the Headscale user. In certain cases
(self-hosted SSO), it may be useful to instead use the
`preferred_username` to set the Headscale username.
This also closes #938.

This adds a config setting to use this claim instead.
The OIDC docs have been updated to include this entry as well.
In addition, this adds an Authelia OIDC example to the docs.

Added OIDC claim integration tests.

Updated the MockOIDC wrapper to take an environment variable that
lets you set the username/email claims to return.

Added two integration tests, TestOIDCEmailGrant and
TestOIDCUsernameGrant, which check the username by checking the FQDN of
clients.

Updated the HTML template shown after OIDC login to show whatever
username is used, based on the Headscale settings.
2023-10-29 16:55:20 -04:00
10 changed files with 734 additions and 36 deletions

View file

@ -0,0 +1,65 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestOIDCEmailGrant
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestOIDCEmailGrant:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestOIDCEmailGrant
if: steps.changed-files.outputs.any_changed == 'true'
run: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-tags ts2019 \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestOIDCEmailGrant$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -0,0 +1,65 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestOIDCUsernameGrant
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestOIDCUsernameGrant:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestOIDCUsernameGrant
if: steps.changed-files.outputs.any_changed == 'true'
run: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-tags ts2019 \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestOIDCUsernameGrant$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -29,12 +29,14 @@ API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553
### Changes ### Changes
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

View file

@ -4,7 +4,9 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"regexp"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/oauth2-proxy/mockoidc" "github.com/oauth2-proxy/mockoidc"
@ -64,6 +66,28 @@ func mockOIDC() error {
accessTTL = newTTL accessTTL = newTTL
} }
mockUsers := os.Getenv("MOCKOIDC_USERS")
users := []mockoidc.User{}
if mockUsers != "" {
userStrings := strings.Split(mockUsers, ",")
userRe := regexp.MustCompile(`^\s*(?P<username>\S+)\s*<(?P<email>\S+@\S+)>\s*$`)
for _, v := range userStrings {
match := userRe.FindStringSubmatch(v)
if match != nil {
// Use the default mockoidc claims for other entries
users = append(users, &mockoidc.MockUser{
Subject: "1234567890",
Email: match[2],
PreferredUsername: match[1],
Phone: "555-987-6543",
Address: "123 Main Street",
Groups: []string{"engineering", "design"},
EmailVerified: true,
})
}
}
}
log.Info().Msgf("Access token TTL: %s", accessTTL) log.Info().Msgf("Access token TTL: %s", accessTTL)
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
@ -71,7 +95,7 @@ func mockOIDC() error {
return err return err
} }
mock, err := getMockOIDC(clientID, clientSecret) mock, err := getMockOIDC(clientID, clientSecret, users)
if err != nil { if err != nil {
return err return err
} }
@ -93,7 +117,7 @@ func mockOIDC() error {
return nil return nil
} }
func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { func getMockOIDC(clientID string, clientSecret string, users []mockoidc.User) (*mockoidc.MockOIDC, error) {
keypair, err := mockoidc.NewKeypair(nil) keypair, err := mockoidc.NewKeypair(nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,5 +135,9 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro
ErrorQueue: &mockoidc.ErrorQueue{}, ErrorQueue: &mockoidc.ErrorQueue{},
} }
for _, v := range users {
mock.QueueUser(v)
}
return &mock, nil return &mock, nil
} }

View file

@ -304,6 +304,16 @@ unix_socket_permission: "0770"
# allowed_users: # allowed_users:
# - alice@example.com # - alice@example.com
# #
# # By default, Headscale will use the OIDC email address claim to determine the username.
# # OIDC also returns a `preferred_username` claim.
# #
# # If `use_username_claim` is set to `true`, then the `preferred_username` claim will
# # be used instead to set the Headscale username.
# # If `use_username_claim` is set to `false`, then the `email` claim will be used
# # to derive the Headscale username (as modified by the `strip_email_domain` entry).
#
# use_username_claim: false
#
# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following # # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following

View file

@ -24,6 +24,11 @@ oidc:
# It resolves environment variables, making integration to systemd's # It resolves environment variables, making integration to systemd's
# `LoadCredential` straightforward: # `LoadCredential` straightforward:
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" #client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
# If provided, the name of a custom OIDC claim for specifying user groups.
# The claim value is expected to be a string or array of strings.
groups_claim: groups
# The OIDC claim to use as the email.
email_claim: email
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
@ -43,6 +48,16 @@ oidc:
allowed_users: allowed_users:
- alice@example.com - alice@example.com
# By default, Headscale will use the OIDC email address claim to determine the username.
# OIDC also returns a `preferred_username` claim.
#
# If `use_username_claim` is set to `true`, then the `preferred_username` claim will
# be used instead to set the Headscale username.
# If `use_username_claim` is set to `false`, then the `email` claim will be used
# to derive the Headscale username (as modified by the `strip_email_domain` entry).
use_username_claim: false
# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
@ -170,3 +185,48 @@ oidc:
``` ```
You can also use `allowed_domains` and `allowed_users` to restrict the users who can authenticate. You can also use `allowed_domains` and `allowed_users` to restrict the users who can authenticate.
## Authelia Example
In order to integrate Headscale with your Authelia instance, you need to generate a client secret add your Headscale instance as a client.
First, generate a client secret. If you are running Authelia inside docker, prepend `docker-compose exec <authelia_container_name>` before these commands:
```shell
authelia crypto hash generate pbkdf2 --variant sha512 --random --random.length 72
```
This will return two strings, a "Random Password" which you will fill into Headscale, and a "Digest" you will fill into Authelia.
In your Authelia configuration, add Headscale under the client section:
```yaml
clients:
- id: headscale
description: Headscale
secret: "DIGEST_STRING_FROM_ABOVE"
public: false
authorization_policy: two_factor
redirect_uris:
- https://your.headscale.domain/oidc/callback
scopes:
- openid
- profile
- email
- groups
```
In your Headscale `config.yaml`, edit the config under `oidc`, filling in the `client_id` to match the `id` line in the Authelia config and filling in `client_secret` from the "Random Password" output.
You may want to tune the `expiry`, `only_start_if_oidc_available`, and other entries. The following are only the required entries.
```yaml
oidc:
issuer: "https://your.authelia.domain"
client_id: "headscale"
client_secret: "RANDOM_PASSWORD_STRING_FROM_ABOVE"
scope: ["openid", "profile", "email", "groups"]
allowed_groups:
- authelia_groups_you_want_to_limit
```
In particular, you may want to set `use_username_claim: true` to use Authelia's `preferred_username` grant to set Headscale usernames.

View file

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
_ "embed" _ "embed"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -41,13 +42,46 @@ var (
"requested node state key expired before authorisation completed", "requested node state key expired before authorisation completed",
) )
errOIDCNodeKeyMissing = errors.New("could not get node key from cache") errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
errOIDCUsernameClaimMissing = errors.New("username claim missing from ID Token")
) )
type IDTokenClaims struct { type IDTokenClaims struct {
Name string `json:"name,omitempty"` // in some cases the groups might be a single value and not a list
Groups []string `json:"groups,omitempty"` Groups stringOrArray
Email string `json:"email"` Email string
Username string `json:"preferred_username,omitempty"` Username string
}
type stringOrArray []string
func (s *stringOrArray) UnmarshalJSON(b []byte) error {
var a []string
if err := json.Unmarshal(b, &a); err == nil {
*s = a
return nil
}
var str string
if err := json.Unmarshal(b, &str); err != nil {
return err
}
*s = []string{str}
return nil
}
type rawClaims map[string]json.RawMessage
func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
val, ok := c[name]
if !ok {
return fmt.Errorf("claim not present")
}
return json.Unmarshal([]byte(val), v)
}
func (c rawClaims) hasClaim(name string) bool {
_, ok := c[name]
return ok
} }
func (h *Headscale) initOIDC() error { func (h *Headscale) initOIDC() error {
@ -215,7 +249,7 @@ func (h *Headscale) OIDCCallback(
// return // return
// } // }
claims, err := extractIDTokenClaims(writer, idToken) claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken)
if err != nil { if err != nil {
return return
} }
@ -242,7 +276,12 @@ func (h *Headscale) OIDCCallback(
return return
} }
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain) userName, err := getUserName(
writer,
claims,
h.cfg.OIDC.UseUsernameClaim,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil { if err != nil {
return return
} }
@ -259,7 +298,7 @@ func (h *Headscale) OIDCCallback(
return return
} }
content, err := renderOIDCCallbackTemplate(writer, claims) content, err := renderOIDCCallbackTemplate(writer, userName)
if err != nil { if err != nil {
return return
} }
@ -355,25 +394,63 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
func extractIDTokenClaims( func extractIDTokenClaims(
writer http.ResponseWriter, writer http.ResponseWriter,
cfg types.OIDCConfig,
idToken *oidc.IDToken, idToken *oidc.IDToken,
) (*IDTokenClaims, error) { ) (*IDTokenClaims, error) {
var claims IDTokenClaims var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil { var rawClaims rawClaims
util.LogErr(err, "Failed to decode id token claims") if err := idToken.Claims(&rawClaims); err != nil {
handleClaimError(writer, err)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err return nil, err
} }
if !rawClaims.hasClaim(cfg.EmailClaim) {
handleClaimError(writer, errOIDCEmailClaimMissing)
return nil, errOIDCEmailClaimMissing
}
if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil {
handleClaimError(writer, err)
return nil, err
}
if !rawClaims.hasClaim(cfg.UsernameClaim) {
handleClaimError(writer, errOIDCUsernameClaimMissing)
return nil, errOIDCUsernameClaimMissing
}
if err := rawClaims.unmarshalClaim(cfg.UsernameClaim, &claims.Username); err != nil {
handleClaimError(writer, err)
return nil, err
}
if rawClaims.hasClaim(cfg.GroupsClaim) {
if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil {
handleClaimError(writer, err)
return nil, err
}
}
return &claims, nil return &claims, nil
} }
func handleClaimError(writer http.ResponseWriter, err error) {
util.LogErr(err, "Failed to decode id token rawClaims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token rawClaims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided, // validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>. // that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains( func validateOIDCAllowedDomains(
@ -539,9 +616,19 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("expiresAt", fmt.Sprintf("%v", expiry)). Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("successfully refreshed node") Msg("successfully refreshed node")
userName, err := getUserName(
writer,
claims,
h.cfg.OIDC.UseUsernameClaim,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
userName = "unknown"
}
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email, User: userName,
Verb: "Reauthenticated", Verb: "Reauthenticated",
}); err != nil { }); err != nil {
log.Error(). log.Error().
@ -576,18 +663,30 @@ func (h *Headscale) validateNodeForOIDCCallback(
func getUserName( func getUserName(
writer http.ResponseWriter, writer http.ResponseWriter,
claims *IDTokenClaims, claims *IDTokenClaims,
useUsernameClaim bool,
stripEmaildomain bool, stripEmaildomain bool,
) (string, error) { ) (string, error) {
var claim string
if useUsernameClaim {
claim = claims.Username
} else {
claim = claims.Email
}
userName, err := util.NormalizeToFQDNRules( userName, err := util.NormalizeToFQDNRules(
claims.Email, claim,
stripEmaildomain, stripEmaildomain,
) )
if err != nil { if err != nil {
util.LogErr(err, "couldn't normalize email") var friendlyErrMsg string
if useUsernameClaim {
friendlyErrMsg = "couldn't normalize username (preferred_username OIDC claim)"
} else {
friendlyErrMsg = "couldn't normalize username (email OIDC claim)"
}
log.Error().Err(err).Caller().Msgf(friendlyErrMsg)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email")) _, werr := writer.Write([]byte(friendlyErrMsg))
if werr != nil { if werr != nil {
util.LogErr(err, "Failed to write response") util.LogErr(err, "Failed to write response")
} }
@ -668,11 +767,11 @@ func (h *Headscale) registerNodeForOIDCCallback(
func renderOIDCCallbackTemplate( func renderOIDCCallbackTemplate(
writer http.ResponseWriter, writer http.ResponseWriter,
claims *IDTokenClaims, user string,
) (*bytes.Buffer, error) { ) (*bytes.Buffer, error) {
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email, User: user,
Verb: "Authenticated", Verb: "Authenticated",
}); err != nil { }); err != nil {
log.Error(). log.Error().

164
hscontrol/oidc_test.go Normal file
View file

@ -0,0 +1,164 @@
package hscontrol
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"reflect"
"testing"
)
func Test_extractIDTokenClaims(t *testing.T) {
tests := []verificationTest{
{
name: "default claim names",
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: &IDTokenClaims{
Groups: []string{"group1", "group2"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "custom claim names",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Groups: []string{"group3", "group4"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "group claim not present",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "email claim not present",
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
token, err := tt.getToken(t)
if err != nil {
t.Errorf("could not parse the token: %v", err)
return
}
if !tt.wantErr {
assert.Equal(t, 200, recorder.Result().StatusCode)
assert.Empty(t, recorder.Result().Header)
}
got, err := extractIDTokenClaims(recorder, tt.cfg, token)
if (err != nil) != tt.wantErr {
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)
return
}
})
}
}
type signingKey struct {
keyID string
key interface{}
pub interface{}
alg jose.SignatureAlgorithm
}
// sign creates a JWS using the private key from the provided payload.
func (s *signingKey) sign(t testing.TB, payload []byte) string {
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}
data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return data
}
type verificationTest struct {
name string
idToken string
cfg types.OIDCConfig
want *IDTokenClaims
wantErr bool
}
func newRSAKey(t testing.TB) *signingKey {
priv, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
t.Fatal(err)
}
return &signingKey{"", priv, priv.Public(), jose.RS256}
}
func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
key := newRSAKey(t)
token := key.sign(t, []byte(v.idToken))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
verifier := oidc.NewVerifier(
"https://foo",
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
&oidc.Config{
SkipClientIDCheck: true,
SkipExpiryCheck: true,
SkipIssuerCheck: true,
InsecureSkipSignatureCheck: true,
},
)
return verifier.Verify(ctx, token)
}

View file

@ -102,7 +102,11 @@ type OIDCConfig struct {
AllowedDomains []string AllowedDomains []string
AllowedUsers []string AllowedUsers []string
AllowedGroups []string AllowedGroups []string
GroupsClaim string
EmailClaim string
UsernameClaim string
StripEmaildomain bool StripEmaildomain bool
UseUsernameClaim bool
Expiry time.Duration Expiry time.Duration
UseExpiryFromToken bool UseExpiryFromToken bool
} }
@ -183,8 +187,12 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.use_username_claim", false)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.groups_claim", "groups")
viper.SetDefault("oidc.email_claim", "email")
viper.SetDefault("oidc.username_claim", "preferred_username")
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("logtail.enabled", false) viper.SetDefault("logtail.enabled", false)
@ -631,6 +639,10 @@ func GetHeadscaleConfig() (*Config, error) {
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
UseUsernameClaim: viper.GetBool("oidc.use_username_claim"),
GroupsClaim: viper.GetString("oidc.groups_claim"),
EmailClaim: viper.GetString("oidc.email_claim"),
UsernameClaim: viper.GetString("oidc.username_claim"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Expiry: func() time.Duration { Expiry: func() time.Duration {
// if set to 0, we assume no expiry // if set to 0, we assume no expiry

View file

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"net/netip" "net/netip"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -38,6 +39,181 @@ type AuthOIDCScenario struct {
mockOIDC *dockertest.Resource mockOIDC *dockertest.Resource
} }
func TestOIDCUsernameGrant(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario()
if err != nil {
t.Errorf("failed to create scenario: %s", err)
}
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
users := make([]string, len(MustTestVersions))
for i := range users {
users[i] = "test-user <test-email@example.com>"
}
userStr := strings.Join(users, ", ")
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, userStr)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "false",
"HEADSCALE_OIDC_USE_USERNAME_CLAIM": "true",
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
)
if err != nil {
t.Errorf("failed to create headscale environment: %s", err)
}
allClients, err := scenario.ListTailscaleClients()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
// Check that clients are registered under the right username
for _, client := range allClients {
fqdn, err := client.FQDN()
if err != nil {
t.Errorf("Unable to get client FQDN: %s", err)
}
if !strings.HasSuffix(fqdn, "test-user.headscale.net") {
t.Errorf(
"Client registered with unexpected username. Client FQDN: %s",
fqdn,
)
}
}
allIps, err := scenario.ListTailscaleClientsIPs()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
err = scenario.WaitForTailscaleSync()
if err != nil {
t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
}
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestOIDCEmailGrant(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario()
if err != nil {
t.Errorf("failed to create scenario: %s", err)
}
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
users := make([]string, len(MustTestVersions))
for i := range users {
users[i] = "test-user <test-email@example.com>"
}
userStr := strings.Join(users, ", ")
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, userStr)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "true",
"HEADSCALE_OIDC_USE_USERNAME_CLAIM": "false",
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
)
if err != nil {
t.Errorf("failed to create headscale environment: %s", err)
}
allClients, err := scenario.ListTailscaleClients()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
// Check that clients are registered under the right username
for _, client := range allClients {
fqdn, err := client.FQDN()
if err != nil {
t.Errorf("Unable to get client FQDN: %s", err)
}
if !strings.HasSuffix(fqdn, "test-email.headscale.net") {
t.Errorf(
"Client registered with unexpected username. Client FQDN: %s",
fqdn,
)
}
}
allIps, err := scenario.ListTailscaleClientsIPs()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
err = scenario.WaitForTailscaleSync()
if err != nil {
t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
}
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestOIDCAuthenticationPingAll(t *testing.T) { func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
@ -54,7 +230,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"user1": len(MustTestVersions), "user1": len(MustTestVersions),
} }
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, "")
assertNoErrf(t, "failed to run mock OIDC server: %s", err) assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{ oidcMap := map[string]string{
@ -62,7 +238,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf(
"%t",
oidcConfig.StripEmaildomain,
),
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
@ -70,7 +249,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
hsic.WithTestName("oidcauthping"), hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -112,14 +294,17 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
"user1": 3, "user1": 3,
} }
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, "")
assertNoErrf(t, "failed to run mock OIDC server: %s", err) assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, "HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf(
"%t",
oidcConfig.StripEmaildomain,
),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", "HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
} }
@ -145,7 +330,11 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
}) })
success := pingAllHelper(t, allClients, allAddrs) success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps)) t.Logf(
"%d successful pings out of %d (before expiry)",
success,
len(allClients)*len(allIps),
)
// This is not great, but this sadly is a time dependent test, so the // This is not great, but this sadly is a time dependent test, so the
// safe thing to do is wait out the whole TTL time before checking if // safe thing to do is wait out the whole TTL time before checking if
@ -191,7 +380,10 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
return nil return nil
} }
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { func (s *AuthOIDCScenario) runMockOIDC(
accessTTL time.Duration,
users string,
) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort() port, err := dockertestutil.RandomFreeHostPort()
if err != nil { if err != nil {
log.Fatalf("could not find an open port: %s", err) log.Fatalf("could not find an open port: %s", err)
@ -215,6 +407,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
fmt.Sprintf("MOCKOIDC_PORT=%d", port), fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret", "MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_USERS=%s", users),
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
}, },
} }