diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 165a511..568a2a0 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -16,10 +16,11 @@ const ( errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined") errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined") errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined") - accessTTL = 10 * time.Minute refreshTTL = 60 * time.Minute ) +var accessTTL = 2 * time.Minute + func init() { rootCmd.AddCommand(mockOidcCmd) } @@ -54,6 +55,16 @@ func mockOIDC() error { if portStr == "" { return errMockOidcPortNotDefined } + accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL") + if accessTTLOverride != "" { + newTTL, err := time.ParseDuration(accessTTLOverride) + if err != nil { + return err + } + accessTTL = newTTL + } + + log.Info().Msgf("Access token TTL: %s", accessTTL) port, err := strconv.Atoi(portStr) if err != nil { diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 0c3c901..5c21f56 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -9,8 +9,10 @@ import ( "log" "net" "net/http" + "net/netip" "strconv" "testing" + "time" "github.com/juanfont/headscale" "github.com/juanfont/headscale/integration/dockertestutil" @@ -22,7 +24,7 @@ import ( const ( dockerContextPath = "../." hsicOIDCMockHashLength = 6 - oidcServerPort = 10000 + defaultAccessTTL = 10 * time.Minute ) var errStatusCodeNotOK = errors.New("status code not OK") @@ -50,7 +52,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { "namespace1": len(TailscaleVersions), } - oidcConfig, err := scenario.runMockOIDC() + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) if err != nil { t.Errorf("failed to run mock OIDC server: %s", err) } @@ -87,20 +89,76 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { t.Errorf("failed wait for tailscale clients to be in sync: %s", err) } - success := 0 + success := pingAll(t, allClients, allIps) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - for _, client := range allClients { - for _, ip := range allIps { - err := client.Ping(ip.String()) - if err != nil { - t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err) - } else { - success++ - } - } + err = scenario.Shutdown() + if err != nil { + t.Errorf("failed to tear down scenario: %s", err) + } +} + +func TestOIDCExpireNodes(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + shortAccessTTL := 5 * time.Minute + + baseScenario, err := NewScenario() + if err != nil { + t.Errorf("failed to create scenario: %s", err) } - t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + + spec := map[string]int{ + "namespace1": len(TailscaleVersions), + } + + oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) + if err != nil { + t.Fatalf("failed to run mock OIDC server: %s", err) + } + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcexpirenodes"), + hsic.WithConfigEnv(oidcMap), + hsic.WithHostnameAsServerURL(), + ) + if err != nil { + t.Errorf("failed to create headscale environment: %s", err) + } + + allClients, err := scenario.ListTailscaleClients() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + allIps, err := scenario.ListTailscaleClientsIPs() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + err = scenario.WaitForTailscaleSync() + if err != nil { + t.Errorf("failed wait for tailscale clients to be in sync: %s", err) + } + + success := pingAll(t, allClients, allIps) + t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps)) + + // await all nodes being logged out after OIDC token expiry + scenario.WaitForTailscaleLogout() err = scenario.Shutdown() if err != nil { @@ -143,7 +201,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } -func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) { +func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*headscale.OIDCConfig, error) { + port, err := dockertestutil.RandomFreeHostPort() + if err != nil { + log.Fatalf("could not find an open port: %s", err) + } + portNotation := fmt.Sprintf("%d/tcp", port) + hash, _ := headscale.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := fmt.Sprintf("hs-oidcmock-%s", hash) @@ -151,16 +215,17 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) { mockOidcOptions := &dockertest.RunOptions{ Name: hostname, Cmd: []string{"headscale", "mockoidc"}, - ExposedPorts: []string{"10000/tcp"}, + ExposedPorts: []string{portNotation}, PortBindings: map[docker.Port][]docker.PortBinding{ - "10000/tcp": {{HostPort: "10000"}}, + docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, }, Networks: []*dockertest.Network{s.Scenario.network}, Env: []string{ fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), - "MOCKOIDC_PORT=10000", + fmt.Sprintf("MOCKOIDC_PORT=%d", port), "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", + fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), }, } @@ -169,7 +234,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) { ContextDir: dockerContextPath, } - err := s.pool.RemoveContainerByName(hostname) + err = s.pool.RemoveContainerByName(hostname) if err != nil { return nil, err } @@ -184,11 +249,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) { } log.Println("Waiting for headscale mock oidc to be ready for tests") - hostEndpoint := fmt.Sprintf( - "%s:%s", - s.mockOIDC.GetIPInNetwork(s.network), - s.mockOIDC.GetPort(fmt.Sprintf("%d/tcp", oidcServerPort)), - ) + hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port) if err := s.pool.Retry(func() error { oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) @@ -215,11 +276,11 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) { log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) return &headscale.OIDCConfig{ - Issuer: fmt.Sprintf("http://%s/oidc", - net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(oidcServerPort))), - ClientID: "superclient", - ClientSecret: "supersecret", - StripEmaildomain: true, + Issuer: fmt.Sprintf("http://%s/oidc", net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port))), + ClientID: "superclient", + ClientSecret: "supersecret", + StripEmaildomain: true, + OnlyStartIfOIDCIsAvailable: true, }, nil } @@ -292,6 +353,24 @@ func (s *AuthOIDCScenario) runTailscaleUp( return fmt.Errorf("failed to up tailscale node: %w", errNoNamespaceAvailable) } +func pingAll(t *testing.T, clients []TailscaleClient, ips []netip.Addr) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, ip := range ips { + err := client.Ping(ip.String()) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + func (s *AuthOIDCScenario) Shutdown() error { err := s.pool.Purge(s.mockOIDC) if err != nil { diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 15c9908..89fdc8e 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -2,6 +2,7 @@ package dockertestutil import ( "errors" + "net" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" @@ -60,3 +61,20 @@ func AddContainerToNetwork( return nil } + +// RandomFreeHostPort asks the kernel for a free open port that is ready to use. +// (from https://github.com/phayes/freeport) +func RandomFreeHostPort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, err + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, err + } + defer listener.Close() + //nolint:forcetypeassert + return listener.Addr().(*net.TCPAddr).Port, nil +}