diff --git a/app.go b/app.go index 3f43d7a..38807c7 100644 --- a/app.go +++ b/app.go @@ -24,7 +24,7 @@ import ( "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/puzpuzpuz/xsync" + "github.com/puzpuzpuz/xsync/v2" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/crypto/acme" @@ -94,7 +94,7 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules []tailcfg.FilterRule - lastStateChange *xsync.MapOf[time.Time] + lastStateChange *xsync.MapOf[string, time.Time] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -884,7 +884,7 @@ func (h *Headscale) setLastStateChangeToNow() { now := time.Now().UTC() - namespaces, err := h.ListNamespacesStr() + namespaces, err := h.ListNamespaces() if err != nil { log.Error(). Caller(). @@ -893,22 +893,22 @@ func (h *Headscale) setLastStateChangeToNow() { } for _, namespace := range namespaces { - lastStateUpdate.WithLabelValues(namespace, "headscale").Set(float64(now.Unix())) + lastStateUpdate.WithLabelValues(namespace.Name, "headscale").Set(float64(now.Unix())) if h.lastStateChange == nil { h.lastStateChange = xsync.NewMapOf[time.Time]() } - h.lastStateChange.Store(namespace, now) + h.lastStateChange.Store(namespace.Name, now) } } -func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { +func (h *Headscale) getLastStateChange(namespaces ...Namespace) time.Time { times := []time.Time{} // getLastStateChange takes a list of namespaces as a "filter", if no namespaces // are past, then use the entier list of namespaces and look for the last update if len(namespaces) > 0 { for _, namespace := range namespaces { - if lastChange, ok := h.lastStateChange.Load(namespace); ok { + if lastChange, ok := h.lastStateChange.Load(namespace.Name); ok { times = append(times, lastChange) } } diff --git a/flake.nix b/flake.nix index 971d7ea..d990dc6 100644 --- a/flake.nix +++ b/flake.nix @@ -31,7 +31,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorSha256 = "sha256-DosFCSiQ5FURbIrt4NcPGkExc84t2MGMqe9XLxNHdIM="; + vendorSha256 = "sha256-nbPCCqGqBFtfbrCeT2WgtUZ+6DerV/bpYpkXtoRaCHE="; ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; }; diff --git a/go.mod b/go.mod index 53f1fd9..40b8a9c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,6 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/prometheus/common v0.37.0 github.com/pterm/pterm v0.12.45 - github.com/puzpuzpuz/xsync v1.4.3 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.5.0 github.com/spf13/viper v1.12.0 @@ -117,6 +116,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/puzpuzpuz/xsync/v2 v2.0.2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4 // indirect diff --git a/go.sum b/go.sum index 180e7c3..9e86117 100644 --- a/go.sum +++ b/go.sum @@ -846,8 +846,8 @@ github.com/pterm/pterm v0.12.36/go.mod h1:NjiL09hFhT/vWjQHSj1athJpx6H8cjpHXNAK5b github.com/pterm/pterm v0.12.40/go.mod h1:ffwPLwlbXxP+rxT0GsgDTzS3y3rmpAO1NMjUkGTYf8s= github.com/pterm/pterm v0.12.45 h1:5HATKLTDjl9D74b0x7yiHzFI7OADlSXK3yHrJNhRwZE= github.com/pterm/pterm v0.12.45/go.mod h1:hJgLlBafm45w/Hr0dKXxY//POD7CgowhePaG1sdPNBg= -github.com/puzpuzpuz/xsync v1.4.3 h1:nS/Iqc4EnpJ8jm/MzJ+e3MUaP2Ys2mqXeEfoxoU0HaM= -github.com/puzpuzpuz/xsync v1.4.3/go.mod h1:K98BYhX3k1dQ2M63t1YNVDanbwUPmBCAhNmVrrxfiGg= +github.com/puzpuzpuz/xsync/v2 v2.0.2 h1:IpXQ8gGkrnZlLGpJLDmq56sYjNhF88n934Yq5BV5fKw= +github.com/puzpuzpuz/xsync/v2 v2.0.2/go.mod h1:gD2H2krq/w52MfPLE+Uy64TzJDVY7lP2znR9qmR35kU= github.com/quasilyte/go-consistent v0.0.0-20190521200055-c6f3937de18c/go.mod h1:5STLWrekHfjyYwxBRVRXNOSewLJ3PWfDJd1VyTS21fI= github.com/quasilyte/go-ruleguard v0.3.1-0.20210203134552-1b5a410e1cc8/go.mod h1:KsAh3x0e7Fkpgs+Q9pNLS5XpFSvYCEVl5gP9Pp1xp30= github.com/quasilyte/go-ruleguard v0.3.13/go.mod h1:Ul8wwdqR6kBVOCt2dipDBkE+T6vAV/iixkrKuRTN1oQ= diff --git a/integration/cli_test.go b/integration/cli_test.go new file mode 100644 index 0000000..43e5551 --- /dev/null +++ b/integration/cli_test.go @@ -0,0 +1,377 @@ +package integration + +import ( + "encoding/json" + "sort" + "testing" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/stretchr/testify/assert" +) + +func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { + str, err := headscale.Execute(command) + if err != nil { + return err + } + + err = json.Unmarshal([]byte(str), result) + if err != nil { + return err + } + + return nil +} + +func TestNamespaceCommand(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario() + assert.NoError(t, err) + + spec := map[string]int{ + "namespace1": 0, + "namespace2": 0, + } + + err = scenario.CreateHeadscaleEnv(spec) + assert.NoError(t, err) + + var listNamespaces []v1.Namespace + err = executeAndUnmarshal(scenario.Headscale(), + []string{ + "headscale", + "namespaces", + "list", + "--output", + "json", + }, + &listNamespaces, + ) + assert.NoError(t, err) + + result := []string{listNamespaces[0].Name, listNamespaces[1].Name} + sort.Strings(result) + + assert.Equal( + t, + []string{"namespace1", "namespace2"}, + result, + ) + + _, err = scenario.Headscale().Execute( + []string{ + "headscale", + "namespaces", + "rename", + "--output", + "json", + "namespace2", + "newname", + }, + ) + assert.NoError(t, err) + + var listAfterRenameNamespaces []v1.Namespace + err = executeAndUnmarshal(scenario.Headscale(), + []string{ + "headscale", + "namespaces", + "list", + "--output", + "json", + }, + &listAfterRenameNamespaces, + ) + assert.NoError(t, err) + + result = []string{listAfterRenameNamespaces[0].Name, listAfterRenameNamespaces[1].Name} + sort.Strings(result) + + assert.Equal( + t, + []string{"namespace1", "newname"}, + result, + ) + + err = scenario.Shutdown() + assert.NoError(t, err) +} + +func TestPreAuthKeyCommand(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + namespace := "preauthkeyspace" + count := 3 + + scenario, err := NewScenario() + assert.NoError(t, err) + + spec := map[string]int{ + namespace: 0, + } + + err = scenario.CreateHeadscaleEnv(spec) + assert.NoError(t, err) + + keys := make([]*v1.PreAuthKey, count) + assert.NoError(t, err) + + for index := 0; index < count; index++ { + var preAuthKey v1.PreAuthKey + err := executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "create", + "--reusable", + "--expiration", + "24h", + "--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &preAuthKey, + ) + assert.NoError(t, err) + + keys[index] = &preAuthKey + } + + assert.Len(t, keys, 3) + + var listedPreAuthKeys []v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(t, err) + + // There is one key created by "scenario.CreateHeadscaleEnv" + assert.Len(t, listedPreAuthKeys, 4) + + assert.Equal( + t, + []string{keys[0].Id, keys[1].Id, keys[2].Id}, + []string{listedPreAuthKeys[1].Id, listedPreAuthKeys[2].Id, listedPreAuthKeys[3].Id}, + ) + + assert.NotEmpty(t, listedPreAuthKeys[1].Key) + assert.NotEmpty(t, listedPreAuthKeys[2].Key) + assert.NotEmpty(t, listedPreAuthKeys[3].Key) + + assert.True(t, listedPreAuthKeys[1].Expiration.AsTime().After(time.Now())) + assert.True(t, listedPreAuthKeys[2].Expiration.AsTime().After(time.Now())) + assert.True(t, listedPreAuthKeys[3].Expiration.AsTime().After(time.Now())) + + assert.True( + t, + listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + t, + listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + t, + listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + + for index := range listedPreAuthKeys { + if index == 0 { + continue + } + + assert.Equal(t, listedPreAuthKeys[index].AclTags, []string{"tag:test1", "tag:test2"}) + } + + // Test key expiry + _, err = scenario.Headscale().Execute( + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "expire", + listedPreAuthKeys[1].Key, + }, + ) + assert.NoError(t, err) + + var listedPreAuthKeysAfterExpire []v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "list", + "--output", + "json", + }, + &listedPreAuthKeysAfterExpire, + ) + assert.NoError(t, err) + + assert.True(t, listedPreAuthKeysAfterExpire[1].Expiration.AsTime().Before(time.Now())) + assert.True(t, listedPreAuthKeysAfterExpire[2].Expiration.AsTime().After(time.Now())) + assert.True(t, listedPreAuthKeysAfterExpire[3].Expiration.AsTime().After(time.Now())) + + err = scenario.Shutdown() + assert.NoError(t, err) +} + +func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + namespace := "pre-auth-key-without-exp-namespace" + + scenario, err := NewScenario() + assert.NoError(t, err) + + spec := map[string]int{ + namespace: 0, + } + + err = scenario.CreateHeadscaleEnv(spec) + assert.NoError(t, err) + + var preAuthKey v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "create", + "--reusable", + "--output", + "json", + }, + &preAuthKey, + ) + assert.NoError(t, err) + + var listedPreAuthKeys []v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(t, err) + + // There is one key created by "scenario.CreateHeadscaleEnv" + assert.Len(t, listedPreAuthKeys, 2) + + assert.True(t, listedPreAuthKeys[1].Expiration.AsTime().After(time.Now())) + assert.True( + t, + listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Minute*70)), + ) + + err = scenario.Shutdown() + assert.NoError(t, err) +} + +func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + namespace := "pre-auth-key-reus-ephm-namespace" + + scenario, err := NewScenario() + assert.NoError(t, err) + + spec := map[string]int{ + namespace: 0, + } + + err = scenario.CreateHeadscaleEnv(spec) + assert.NoError(t, err) + + var preAuthReusableKey v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "create", + "--reusable=true", + "--output", + "json", + }, + &preAuthReusableKey, + ) + assert.NoError(t, err) + + var preAuthEphemeralKey v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "create", + "--ephemeral=true", + "--output", + "json", + }, + &preAuthEphemeralKey, + ) + assert.NoError(t, err) + + assert.True(t, preAuthEphemeralKey.GetEphemeral()) + assert.False(t, preAuthEphemeralKey.GetReusable()) + + var listedPreAuthKeys []v1.PreAuthKey + err = executeAndUnmarshal( + scenario.Headscale(), + []string{ + "headscale", + "preauthkeys", + "--namespace", + namespace, + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(t, err) + + // There is one key created by "scenario.CreateHeadscaleEnv" + assert.Len(t, listedPreAuthKeys, 3) + + err = scenario.Shutdown() + assert.NoError(t, err) +} diff --git a/integration/control.go b/integration/control.go index 58a4661..33a687c 100644 --- a/integration/control.go +++ b/integration/control.go @@ -6,6 +6,7 @@ import ( type ControlServer interface { Shutdown() error + Execute(command []string) (string, error) GetHealthEndpoint() string GetEndpoint() string WaitForReady() error diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index a6372a5..7db8233 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -103,6 +103,29 @@ func (t *HeadscaleInContainer) Shutdown() error { return t.pool.Purge(t.container) } +func (t *HeadscaleInContainer) Execute( + command []string, +) (string, error) { + log.Println("command", command) + log.Printf("running command for %s\n", t.hostname) + stdout, stderr, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + log.Printf("command stderr: %s\n", stderr) + + return "", err + } + + if stdout != "" { + log.Printf("command stdout: %s\n", stdout) + } + + return stdout, nil +} + func (t *HeadscaleInContainer) GetIP() string { return t.container.GetIPInNetwork(t.network) } diff --git a/namespaces.go b/namespaces.go index ac8913f..d169881 100644 --- a/namespaces.go +++ b/namespaces.go @@ -148,21 +148,6 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) { return namespaces, nil } -func (h *Headscale) ListNamespacesStr() ([]string, error) { - namespaces, err := h.ListNamespaces() - if err != nil { - return []string{}, err - } - - namespaceStrs := make([]string, len(namespaces)) - - for index, namespace := range namespaces { - namespaceStrs[index] = namespace.Name - } - - return namespaceStrs, nil -} - // ListMachinesInNamespace gets all the nodes in a given namespace. func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { err := CheckForFQDNRules(name) diff --git a/protocol_common_poll.go b/protocol_common_poll.go index e697286..246d0ce 100644 --- a/protocol_common_poll.go +++ b/protocol_common_poll.go @@ -449,7 +449,7 @@ func (h *Headscale) pollNetMapStream( Bool("noise", isNoise). Str("machine", machine.Hostname). Time("last_successful_update", lastUpdate). - Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). + Time("last_state_change", h.getLastStateChange(machine.Namespace)). Msgf("There has been updates since the last successful update to %s", machine.Hostname) data, err := h.getMapResponseData(mapRequest, machine, false) if err != nil { @@ -549,7 +549,7 @@ func (h *Headscale) pollNetMapStream( Bool("noise", isNoise). Str("machine", machine.Hostname). Time("last_successful_update", lastUpdate). - Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). + Time("last_state_change", h.getLastStateChange(machine.Namespace)). Msgf("%s is up to date", machine.Hostname) }