Merge remote-tracking branch 'meson800/oidc_username_claim'

This commit is contained in:
Tyler Beckman 2023-11-12 16:34:35 -07:00
commit 5fc56adf11
Signed by: Ty
GPG key ID: 2813440C772555A4
9 changed files with 468 additions and 21 deletions

View file

@ -0,0 +1,65 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestOIDCEmailGrant
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestOIDCEmailGrant:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestOIDCEmailGrant
if: steps.changed-files.outputs.any_changed == 'true'
run: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-tags ts2019 \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestOIDCEmailGrant$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -0,0 +1,65 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestOIDCUsernameGrant
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestOIDCUsernameGrant:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestOIDCUsernameGrant
if: steps.changed-files.outputs.any_changed == 'true'
run: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-tags ts2019 \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestOIDCUsernameGrant$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -29,6 +29,7 @@ API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553
### Changes ### Changes
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)

View file

@ -4,7 +4,9 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"regexp"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/oauth2-proxy/mockoidc" "github.com/oauth2-proxy/mockoidc"
@ -64,6 +66,28 @@ func mockOIDC() error {
accessTTL = newTTL accessTTL = newTTL
} }
mockUsers := os.Getenv("MOCKOIDC_USERS")
users := []mockoidc.User{}
if mockUsers != "" {
userStrings := strings.Split(mockUsers, ",")
userRe := regexp.MustCompile(`^\s*(?P<username>\S+)\s*<(?P<email>\S+@\S+)>\s*$`)
for _, v := range userStrings {
match := userRe.FindStringSubmatch(v)
if match != nil {
// Use the default mockoidc claims for other entries
users = append(users, &mockoidc.MockUser{
Subject: "1234567890",
Email: match[2],
PreferredUsername: match[1],
Phone: "555-987-6543",
Address: "123 Main Street",
Groups: []string{"engineering", "design"},
EmailVerified: true,
})
}
}
}
log.Info().Msgf("Access token TTL: %s", accessTTL) log.Info().Msgf("Access token TTL: %s", accessTTL)
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
@ -71,7 +95,7 @@ func mockOIDC() error {
return err return err
} }
mock, err := getMockOIDC(clientID, clientSecret) mock, err := getMockOIDC(clientID, clientSecret, users)
if err != nil { if err != nil {
return err return err
} }
@ -93,7 +117,7 @@ func mockOIDC() error {
return nil return nil
} }
func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { func getMockOIDC(clientID string, clientSecret string, users []mockoidc.User) (*mockoidc.MockOIDC, error) {
keypair, err := mockoidc.NewKeypair(nil) keypair, err := mockoidc.NewKeypair(nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,5 +135,9 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro
ErrorQueue: &mockoidc.ErrorQueue{}, ErrorQueue: &mockoidc.ErrorQueue{},
} }
for _, v := range users {
mock.QueueUser(v)
}
return &mock, nil return &mock, nil
} }

View file

@ -304,6 +304,16 @@ unix_socket_permission: "0770"
# allowed_users: # allowed_users:
# - alice@example.com # - alice@example.com
# #
# # By default, Headscale will use the OIDC email address claim to determine the username.
# # OIDC also returns a `preferred_username` claim.
# #
# # If `use_username_claim` is set to `true`, then the `preferred_username` claim will
# # be used instead to set the Headscale username.
# # If `use_username_claim` is set to `false`, then the `email` claim will be used
# # to derive the Headscale username (as modified by the `strip_email_domain` entry).
#
# use_username_claim: false
#
# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following # # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following

View file

@ -43,6 +43,16 @@ oidc:
allowed_users: allowed_users:
- alice@example.com - alice@example.com
# By default, Headscale will use the OIDC email address claim to determine the username.
# OIDC also returns a `preferred_username` claim.
#
# If `use_username_claim` is set to `true`, then the `preferred_username` claim will
# be used instead to set the Headscale username.
# If `use_username_claim` is set to `false`, then the `email` claim will be used
# to derive the Headscale username (as modified by the `strip_email_domain` entry).
use_username_claim: false
# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
@ -170,3 +180,48 @@ oidc:
``` ```
You can also use `allowed_domains` and `allowed_users` to restrict the users who can authenticate. You can also use `allowed_domains` and `allowed_users` to restrict the users who can authenticate.
## Authelia Example
In order to integrate Headscale with your Authelia instance, you need to generate a client secret add your Headscale instance as a client.
First, generate a client secret. If you are running Authelia inside docker, prepend `docker-compose exec <authelia_container_name>` before these commands:
```shell
authelia crypto hash generate pbkdf2 --variant sha512 --random --random.length 72
```
This will return two strings, a "Random Password" which you will fill into Headscale, and a "Digest" you will fill into Authelia.
In your Authelia configuration, add Headscale under the client section:
```yaml
clients:
- id: headscale
description: Headscale
secret: "DIGEST_STRING_FROM_ABOVE"
public: false
authorization_policy: two_factor
redirect_uris:
- https://your.headscale.domain/oidc/callback
scopes:
- openid
- profile
- email
- groups
```
In your Headscale `config.yaml`, edit the config under `oidc`, filling in the `client_id` to match the `id` line in the Authelia config and filling in `client_secret` from the "Random Password" output.
You may want to tune the `expiry`, `only_start_if_oidc_available`, and other entries. The following are only the required entries.
```yaml
oidc:
issuer: "https://your.authelia.domain"
client_id: "headscale"
client_secret: "RANDOM_PASSWORD_STRING_FROM_ABOVE"
scope: ["openid", "profile", "email", "groups"]
allowed_groups:
- authelia_groups_you_want_to_limit
```
In particular, you may want to set `use_username_claim: true` to use Authelia's `preferred_username` grant to set Headscale usernames.

View file

@ -242,7 +242,12 @@ func (h *Headscale) OIDCCallback(
return return
} }
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain) userName, err := getUserName(
writer,
claims,
h.cfg.OIDC.UseUsernameClaim,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil { if err != nil {
return return
} }
@ -259,7 +264,7 @@ func (h *Headscale) OIDCCallback(
return return
} }
content, err := renderOIDCCallbackTemplate(writer, claims) content, err := renderOIDCCallbackTemplate(writer, userName)
if err != nil { if err != nil {
return return
} }
@ -539,9 +544,19 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("expiresAt", fmt.Sprintf("%v", expiry)). Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("successfully refreshed node") Msg("successfully refreshed node")
userName, err := getUserName(
writer,
claims,
h.cfg.OIDC.UseUsernameClaim,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
userName = "unknown"
}
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email, User: userName,
Verb: "Reauthenticated", Verb: "Reauthenticated",
}); err != nil { }); err != nil {
log.Error(). log.Error().
@ -576,18 +591,30 @@ func (h *Headscale) validateNodeForOIDCCallback(
func getUserName( func getUserName(
writer http.ResponseWriter, writer http.ResponseWriter,
claims *IDTokenClaims, claims *IDTokenClaims,
useUsernameClaim bool,
stripEmaildomain bool, stripEmaildomain bool,
) (string, error) { ) (string, error) {
var claim string
if useUsernameClaim {
claim = claims.Username
} else {
claim = claims.Email
}
userName, err := util.NormalizeToFQDNRules( userName, err := util.NormalizeToFQDNRules(
claims.Email, claim,
stripEmaildomain, stripEmaildomain,
) )
if err != nil { if err != nil {
util.LogErr(err, "couldn't normalize email") var friendlyErrMsg string
if useUsernameClaim {
friendlyErrMsg = "couldn't normalize username (preferred_username OIDC claim)"
} else {
friendlyErrMsg = "couldn't normalize username (email OIDC claim)"
}
log.Error().Err(err).Caller().Msgf(friendlyErrMsg)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email")) _, werr := writer.Write([]byte(friendlyErrMsg))
if werr != nil { if werr != nil {
util.LogErr(err, "Failed to write response") util.LogErr(err, "Failed to write response")
} }
@ -668,11 +695,11 @@ func (h *Headscale) registerNodeForOIDCCallback(
func renderOIDCCallbackTemplate( func renderOIDCCallbackTemplate(
writer http.ResponseWriter, writer http.ResponseWriter,
claims *IDTokenClaims, user string,
) (*bytes.Buffer, error) { ) (*bytes.Buffer, error) {
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email, User: user,
Verb: "Authenticated", Verb: "Authenticated",
}); err != nil { }); err != nil {
log.Error(). log.Error().

View file

@ -103,6 +103,7 @@ type OIDCConfig struct {
AllowedUsers []string AllowedUsers []string
AllowedGroups []string AllowedGroups []string
StripEmaildomain bool StripEmaildomain bool
UseUsernameClaim bool
Expiry time.Duration Expiry time.Duration
UseExpiryFromToken bool UseExpiryFromToken bool
} }
@ -183,6 +184,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.use_username_claim", false)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
@ -631,6 +633,7 @@ func GetHeadscaleConfig() (*Config, error) {
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
UseUsernameClaim: viper.GetBool("oidc.use_username_claim"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Expiry: func() time.Duration { Expiry: func() time.Duration {
// if set to 0, we assume no expiry // if set to 0, we assume no expiry

View file

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"net/netip" "net/netip"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -38,6 +39,181 @@ type AuthOIDCScenario struct {
mockOIDC *dockertest.Resource mockOIDC *dockertest.Resource
} }
func TestOIDCUsernameGrant(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario()
if err != nil {
t.Errorf("failed to create scenario: %s", err)
}
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
users := make([]string, len(MustTestVersions))
for i := range users {
users[i] = "test-user <test-email@example.com>"
}
userStr := strings.Join(users, ", ")
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, userStr)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "false",
"HEADSCALE_OIDC_USE_USERNAME_CLAIM": "true",
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
)
if err != nil {
t.Errorf("failed to create headscale environment: %s", err)
}
allClients, err := scenario.ListTailscaleClients()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
// Check that clients are registered under the right username
for _, client := range allClients {
fqdn, err := client.FQDN()
if err != nil {
t.Errorf("Unable to get client FQDN: %s", err)
}
if !strings.HasSuffix(fqdn, "test-user.headscale.net") {
t.Errorf(
"Client registered with unexpected username. Client FQDN: %s",
fqdn,
)
}
}
allIps, err := scenario.ListTailscaleClientsIPs()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
err = scenario.WaitForTailscaleSync()
if err != nil {
t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
}
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestOIDCEmailGrant(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario()
if err != nil {
t.Errorf("failed to create scenario: %s", err)
}
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
users := make([]string, len(MustTestVersions))
for i := range users {
users[i] = "test-user <test-email@example.com>"
}
userStr := strings.Join(users, ", ")
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, userStr)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "true",
"HEADSCALE_OIDC_USE_USERNAME_CLAIM": "false",
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
)
if err != nil {
t.Errorf("failed to create headscale environment: %s", err)
}
allClients, err := scenario.ListTailscaleClients()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
// Check that clients are registered under the right username
for _, client := range allClients {
fqdn, err := client.FQDN()
if err != nil {
t.Errorf("Unable to get client FQDN: %s", err)
}
if !strings.HasSuffix(fqdn, "test-email.headscale.net") {
t.Errorf(
"Client registered with unexpected username. Client FQDN: %s",
fqdn,
)
}
}
allIps, err := scenario.ListTailscaleClientsIPs()
if err != nil {
t.Errorf("failed to get clients: %s", err)
}
err = scenario.WaitForTailscaleSync()
if err != nil {
t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
}
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestOIDCAuthenticationPingAll(t *testing.T) { func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
@ -54,7 +230,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"user1": len(MustTestVersions), "user1": len(MustTestVersions),
} }
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, "")
assertNoErrf(t, "failed to run mock OIDC server: %s", err) assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{ oidcMap := map[string]string{
@ -62,7 +238,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp", "CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf(
"%t",
oidcConfig.StripEmaildomain,
),
} }
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
@ -70,7 +249,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
hsic.WithTestName("oidcauthping"), hsic.WithTestName("oidcauthping"),
hsic.WithConfigEnv(oidcMap), hsic.WithConfigEnv(oidcMap),
hsic.WithHostnameAsServerURL(), hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(oidcConfig.ClientSecret),
),
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
@ -112,14 +294,17 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
"user1": 3, "user1": 3,
} }
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, "")
assertNoErrf(t, "failed to run mock OIDC server: %s", err) assertNoErrf(t, "failed to run mock OIDC server: %s", err)
oidcMap := map[string]string{ oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, "HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf(
"%t",
oidcConfig.StripEmaildomain,
),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", "HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
} }
@ -145,7 +330,11 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
}) })
success := pingAllHelper(t, allClients, allAddrs) success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps)) t.Logf(
"%d successful pings out of %d (before expiry)",
success,
len(allClients)*len(allIps),
)
// This is not great, but this sadly is a time dependent test, so the // This is not great, but this sadly is a time dependent test, so the
// safe thing to do is wait out the whole TTL time before checking if // safe thing to do is wait out the whole TTL time before checking if
@ -191,7 +380,10 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
return nil return nil
} }
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { func (s *AuthOIDCScenario) runMockOIDC(
accessTTL time.Duration,
users string,
) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort() port, err := dockertestutil.RandomFreeHostPort()
if err != nil { if err != nil {
log.Fatalf("could not find an open port: %s", err) log.Fatalf("could not find an open port: %s", err)
@ -215,6 +407,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
fmt.Sprintf("MOCKOIDC_PORT=%d", port), fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret", "MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_USERS=%s", users),
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
}, },
} }