From 7e8bf4bfe5390b2732021fb82368b870a4193ae7 Mon Sep 17 00:00:00 2001 From: Alexander Halbarth Date: Tue, 16 Jan 2024 16:04:03 +0100 Subject: [PATCH 01/13] Add Customization Options to DERP Map entry of integrated DERP server (#1565) Co-authored-by: Alexander Halbarth Co-authored-by: Bela Lemle Co-authored-by: Kristoffer Dalby --- CHANGELOG.md | 1 + config-example.yaml | 10 +++++ hscontrol/app.go | 6 ++- hscontrol/db/routes.go | 2 +- hscontrol/derp/server/derp_server.go | 4 ++ hscontrol/types/config.go | 62 ++++++++++++++++++---------- hscontrol/types/node.go | 2 +- integration/general_test.go | 1 - 8 files changed, 61 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6460d7e..a5441a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) tak Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. 
[#1565](https://github.com/juanfont/headscale/pull/1565) ## 0.22.3 (2023-05-12) diff --git a/config-example.yaml b/config-example.yaml index 5105dcd..96a654a 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -94,6 +94,16 @@ derp: # private_key_path: /var/lib/headscale/derp_server_private.key + # This flag can be used, so the DERP map entry for the embedded DERP server is not written automatically, + # it enables the creation of your very own DERP map entry using a locally available file with the parameter DERP.paths + # If you enable the DERP server and set this to false, it is required to add the DERP server to the DERP map using DERP.paths + automatically_add_embedded_derp_region: true + + # For better connection stability (especially when using an Exit-Node and DNS is not working), + # it is possible to optionall add the public IPv4 and IPv6 address to the Derp-Map using: + ipv4: 1.2.3.4 + ipv6: 2001:db8::1 + # List of externally available DERP maps encoded in JSON urls: - https://controlplane.tailscale.com/derpmap/default diff --git a/hscontrol/app.go b/hscontrol/app.go index 5327d6f..75dfdde 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -268,7 +268,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { case <-ticker.C: log.Info().Msg("Fetching DERPMap updates") h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - if h.cfg.DERP.ServerEnabled { + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() h.DERPMap.Regions[region.RegionID] = ®ion } @@ -501,7 +501,9 @@ func (h *Headscale) Serve() error { return err } - h.DERPMap.Regions[region.RegionID] = ®ion + if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { + h.DERPMap.Regions[region.RegionID] = ®ion + } go h.DERPServer.ServeSTUN() } diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 51c7f3b..aed9776 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -349,7 +349,7 @@ func 
(hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er // SaveNodeRoutes takes a node and updates the database with // the new routes. -// It returns a bool wheter an update should be sent as the +// It returns a bool whether an update should be sent as the // saved route impacts nodes. func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { hsdb.mu.Lock() diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 59e4028..ad325c7 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { RegionID: d.cfg.ServerRegionID, HostName: host, DERPPort: port, + IPv4: d.cfg.IPv4, + IPv6: d.cfg.IPv6, }, }, } @@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { localDERPregion.Nodes[0].STUNPort = portSTUN log.Info().Caller().Msgf("DERP region: %+v", localDERPregion) + log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0]) return localDERPregion, nil } @@ -208,6 +211,7 @@ func DERPProbeHandler( // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. 
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns +// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL func DERPBootstrapDNSHandler( derpMap *tailcfg.DERPMap, ) func(http.ResponseWriter, *http.Request) { diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4b29c4b..8e61973 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -107,16 +107,19 @@ type OIDCConfig struct { } type DERPConfig struct { - ServerEnabled bool - ServerRegionID int - ServerRegionCode string - ServerRegionName string - ServerPrivateKeyPath string - STUNAddr string - URLs []url.URL - Paths []string - AutoUpdate bool - UpdateFrequency time.Duration + ServerEnabled bool + AutomaticallyAddEmbeddedDerpRegion bool + ServerRegionID int + ServerRegionCode string + ServerRegionName string + ServerPrivateKeyPath string + STUNAddr string + URLs []url.URL + Paths []string + AutoUpdate bool + UpdateFrequency time.Duration + IPv4 string + IPv6 string } type LogTailConfig struct { @@ -169,6 +172,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("derp.server.enabled", false) viper.SetDefault("derp.server.stun.enabled", true) + viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true) viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock") viper.SetDefault("unix_socket_permission", "0o770") @@ -286,8 +290,14 @@ func GetDERPConfig() DERPConfig { serverRegionCode := viper.GetString("derp.server.region_code") serverRegionName := viper.GetString("derp.server.region_name") stunAddr := viper.GetString("derp.server.stun_listen_addr") - privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path")) - + privateKeyPath := util.AbsolutePathFromConfigPath( + viper.GetString("derp.server.private_key_path"), + ) + ipv4 := viper.GetString("derp.server.ipv4") + ipv6 := viper.GetString("derp.server.ipv6") + 
automaticallyAddEmbeddedDerpRegion := viper.GetBool( + "derp.server.automatically_add_embedded_derp_region", + ) if serverEnabled && stunAddr == "" { log.Fatal(). Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") @@ -310,20 +320,28 @@ func GetDERPConfig() DERPConfig { paths := viper.GetStringSlice("derp.paths") + if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 { + log.Fatal(). + Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths") + } + autoUpdate := viper.GetBool("derp.auto_update_enabled") updateFrequency := viper.GetDuration("derp.update_frequency") return DERPConfig{ - ServerEnabled: serverEnabled, - ServerRegionID: serverRegionID, - ServerRegionCode: serverRegionCode, - ServerRegionName: serverRegionName, - ServerPrivateKeyPath: privateKeyPath, - STUNAddr: stunAddr, - URLs: urls, - Paths: paths, - AutoUpdate: autoUpdate, - UpdateFrequency: updateFrequency, + ServerEnabled: serverEnabled, + ServerRegionID: serverRegionID, + ServerRegionCode: serverRegionCode, + ServerRegionName: serverRegionName, + ServerPrivateKeyPath: privateKeyPath, + STUNAddr: stunAddr, + URLs: urls, + Paths: paths, + AutoUpdate: autoUpdate, + UpdateFrequency: updateFrequency, + IPv4: ipv4, + IPv6: ipv6, + AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion, } } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 9b2ba76..4434264 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri // inform peers about smaller changes to the node. // When a field is added to this function, remember to also add it to: // - node.ApplyPeerChange -// - logTracePeerChange in poll.go +// - logTracePeerChange in poll.go. 
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { ret := tailcfg.PeerChange{ NodeID: tailcfg.NodeID(node.ID), diff --git a/integration/general_test.go b/integration/general_test.go index c092844..15c3a72 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -320,7 +320,6 @@ func TestTaildrop(t *testing.T) { if err != nil { t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) } - } curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} err = retry(10, 1*time.Second, func() error { From 65376e2842e667d58c0f166b8fc533e09af2b3e0 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 18 Jan 2024 16:36:47 +0100 Subject: [PATCH 02/13] ensure renabled auto-approve routes works (#1670) --- ...v2-TestEnableDisableAutoApprovedRoute.yaml | 67 ++++++++ hscontrol/db/routes.go | 21 ++- hscontrol/poll.go | 8 + integration/route_test.go | 143 ++++++++++++++++++ 4 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml diff --git a/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml b/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml new file mode 100644 index 0000000..def07cc --- /dev/null +++ b/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestEnableDisableAutoApprovedRoute + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestEnableDisableAutoApprovedRoute: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + 
fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestEnableDisableAutoApprovedRoute + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... \ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestEnableDisableAutoApprovedRoute$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index aed9776..dcf00bc 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -639,13 +639,19 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, node *types.Node, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + if len(aclPolicy.AutoApprovers.ExitNode) == 0 && len(aclPolicy.AutoApprovers.Routes) == 0 { + // No autoapprovers configured + return nil + } if len(node.IPAddresses) == 0 { - return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs + // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs + return nil 
} + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + routes, err := hsdb.getNodeAdvertisedRoutes(node) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). @@ -657,6 +663,8 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( return err } + log.Trace().Interface("routes", routes).Msg("routes for autoapproving") + approvedRoutes := types.Routes{} for _, advertisedRoute := range routes { @@ -676,6 +684,13 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( return err } + log.Trace(). + Str("node", node.Hostname). + Str("user", node.User.Name). + Strs("routeApprovers", routeApprovers). + Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()). + Msg("looking up route for autoapproving") + for _, approvedAlias := range routeApprovers { if approvedAlias == node.User.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 568f209..b4ac6b5 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -125,6 +125,14 @@ func (h *Headscale) handlePoll( return } + + if h.ACLPolicy != nil { + // update routes with peer information + err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + if err != nil { + logErr(err, "Error running auto approved routes") + } + } } // Services is mostly useful for discovery and not critical, diff --git a/integration/route_test.go b/integration/route_test.go index 489165a..3edab6a 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -10,6 +10,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -778,3 +779,145 @@ func TestHASubnetRouterFailover(t *testing.T) { ) } } + +func TestEnableDisableAutoApprovedRoute(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + expectedRoutes := "172.0.0.0/24" + + user := "enable-disable-routing" + + 
scenario, err := NewScenario() + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.Shutdown() + + spec := map[string]int{ + user: 1, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + &policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:approve": {user}, + }, + AutoApprovers: policy.AutoApprovers{ + Routes: map[string][]string{ + expectedRoutes: {"tag:approve"}, + }, + }, + }, + )) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + subRouter1 := allClients[0] + + // Initially advertise route + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes, + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + + time.Sleep(10 * time.Second) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + assertNoErr(t, err) + assert.Len(t, routes, 1) + + // All routes should be auto approved and enabled + assert.Equal(t, true, routes[0].GetAdvertised()) + assert.Equal(t, true, routes[0].GetEnabled()) + assert.Equal(t, true, routes[0].GetIsPrimary()) + + // Stop advertising route + command = []string{ + "tailscale", + "set", + "--advertise-routes=", + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to remove advertised route: %s", err) + + time.Sleep(10 * time.Second) + + var notAdvertisedRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + 
"json", + }, + ¬AdvertisedRoutes, + ) + assertNoErr(t, err) + assert.Len(t, notAdvertisedRoutes, 1) + + // Route is no longer advertised + assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised()) + assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled()) + assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary()) + + // Advertise route again + command = []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes, + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + + time.Sleep(10 * time.Second) + + var reAdvertisedRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &reAdvertisedRoutes, + ) + assertNoErr(t, err) + assert.Len(t, reAdvertisedRoutes, 1) + + // All routes should be auto approved and enabled + assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised()) + assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled()) + assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary()) +} From 1e22f17f36f8c13185dff269e6a00424b49b9568 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 18 Jan 2024 17:30:25 +0100 Subject: [PATCH 03/13] node selfupdate and fix subnet router when ACL is enabled (#1673) Fixes #1604 Signed-off-by: Kristoffer Dalby --- ...est-integration-v2-TestSubnetRouteACL.yaml | 67 +++++ hscontrol/db/node.go | 13 + hscontrol/mapper/mapper.go | 12 + hscontrol/policy/acls.go | 15 + hscontrol/policy/acls_test.go | 75 +++++ hscontrol/poll.go | 25 ++ integration/route_test.go | 272 ++++++++++++++++++ integration/tailscale.go | 2 + integration/tsic/tsic.go | 25 ++ 9 files changed, 506 insertions(+) create mode 100644 .github/workflows/test-integration-v2-TestSubnetRouteACL.yaml diff --git a/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml b/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml new file mode 100644 index 0000000..3cb3f11 --- /dev/null +++ 
b/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestSubnetRouteACL + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestSubnetRouteACL: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestSubnetRouteACL + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... 
\ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestSubnetRouteACL$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ce535b9..880a0e1 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -739,6 +739,19 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro stateUpdate, node.MachineKey.String()) } + // Send an update to the node itself with to ensure it + // has an updated packetfilter allowing the new route + // if it is defined in the ACL. + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + hsdb.notifier.NotifyByMachineKey( + selfUpdate, + node.MachineKey) + } + return nil } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index d6404ce..9998f12 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -278,6 +278,18 @@ func (m *Mapper) LiteMapResponse( return nil, err } + rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + pol, + node, + nodeMapToList(m.peers), + ) + if err != nil { + return nil, err + } + + resp.PacketFilter = policy.ReduceFilterRules(node, rules) + resp.SSHPolicy = sshPolicy + return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) } diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 4798d81..1dd664c 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F if node.IPAddresses.InIPSet(expanded) { dests = append(dests, dest) } + + // If the node exposes routes, ensure they are note removed + // 
when the filters are reduced. + if node.Hostinfo != nil { + // TODO(kradalby): Evaluate if we should only keep + // the routes if the route is enabled. This will + // require database access in this part of the code. + if len(node.Hostinfo.RoutableIPs) > 0 { + for _, routableIP := range node.Hostinfo.RoutableIPs { + if expanded.ContainsPrefix(routableIP) { + dests = append(dests, dest) + } + } + } + } } if len(dests) > 0 { diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index c048778..4a74bda 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) { }, want: []tailcfg.FilterRule{}, }, + { + name: "1604-subnet-routers-are-preserved", + pol: ACLPolicy{ + Groups: Groups{ + "group:admins": {"user1"}, + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"group:admins:*"}, + }, + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"10.33.0.0/16:*"}, + }, + }, + }, + node: &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user1"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.2"), + netip.MustParseAddr("fd7a:115c:a1e0::2"), + }, + User: types.User{Name: "user1"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.1/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::1/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + 
"fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.33.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, } for _, tt := range tests { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index b4ac6b5..c867f26 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -153,6 +153,8 @@ func (h *Headscale) handlePoll( return } + // Send an update to all peers to propagate the new routes + // available. stateUpdate := types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, @@ -164,6 +166,19 @@ func (h *Headscale) handlePoll( node.MachineKey.String()) } + // Send an update to the node itself with to ensure it + // has an updated packetfilter allowing the new route + // if it is defined in the ACL. + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + h.nodeNotifier.NotifyByMachineKey( + selfUpdate, + node.MachineKey) + } + return } } @@ -378,6 +393,16 @@ func (h *Headscale) handlePoll( var data []byte var err error + // Ensure the node object is updated, for example, there + // might have been a hostinfo update in a sidechannel + // which contains data needed to generate a map response. 
+ node, err = h.db.GetNodeByMachineKey(node.MachineKey) + if err != nil { + logErr(err, "Could not get machine from db") + + return + } + switch update.Type { case types.StateFullUpdate: logInfo("Sending Full MapResponse") diff --git a/integration/route_test.go b/integration/route_test.go index 3edab6a..741ba24 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -9,11 +9,15 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "tailscale.com/types/ipproto" + "tailscale.com/wgengine/filter" ) // This test is both testing the routes command and the propagation of @@ -921,3 +925,271 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled()) assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary()) } + +// TestSubnetRouteACL verifies that Subnet routes are distributed +// as expected when ACLs are activated. 
+// It implements the issue from +// https://github.com/juanfont/headscale/issues/1604 +func TestSubnetRouteACL(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + user := "subnet-route-acl" + + scenario, err := NewScenario() + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.Shutdown() + + spec := map[string]int{ + user: 2, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + &policy.ACLPolicy{ + Groups: policy.Groups{ + "group:admins": {user}, + }, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"group:admins:*"}, + }, + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"10.33.0.0/16:*"}, + }, + // { + // Action: "accept", + // Sources: []string{"group:admins"}, + // Destinations: []string{"0.0.0.0/0:*"}, + // }, + }, + }, + )) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + expectedRoutes := map[string]string{ + "1": "10.33.0.0/16", + } + + // Sort nodes by ID + sort.SliceStable(allClients, func(i, j int) bool { + statusI, err := allClients[i].Status() + if err != nil { + return false + } + + statusJ, err := allClients[j].Status() + if err != nil { + return false + } + + return statusI.Self.ID < statusJ.Self.ID + }) + + subRouter1 := allClients[0] + + client := allClients[1] + + // advertise HA route on node 1 and 2 + // ID 1 will be primary + // ID 2 will be secondary + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + if route, ok := expectedRoutes[string(status.Self.ID)]; ok { + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + route, + } + _, _, err = client.Execute(command) + 
assertNoErrf(t, "failed to advertise route: %s", err) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + + assertNoErr(t, err) + assert.Len(t, routes, 1) + + for _, route := range routes { + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) + } + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + } + } + + // Enable all routes + for _, route := range routes { + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "enable", + "--route", + strconv.Itoa(int(route.GetId())), + }) + assertNoErr(t, err) + } + + time.Sleep(5 * time.Second) + + var enablingRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &enablingRoutes, + ) + assertNoErr(t, err) + assert.Len(t, enablingRoutes, 1) + + // Node 1 has active route + assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) + assert.Equal(t, true, enablingRoutes[0].GetEnabled()) + assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) + + // Verify that the client has routes from the primary machine + srs1, _ := subRouter1.Status() + + clientStatus, err := client.Status() + assertNoErr(t, err) + + srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] + + assertNotNil(t, srs1PeerStatus.PrimaryRoutes) + + t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice()) + assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1) + assert.Contains( + t, + 
srs1PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + ) + + clientNm, err := client.Netmap() + assertNoErr(t, err) + + wantClientFilter := []filter.Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.2/32"), + Ports: filter.PortRange{0, 0xffff}, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" { + t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) + } + + subnetNm, err := subRouter1.Netmap() + assertNoErr(t, err) + + wantSubnetFilter := []filter.Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.1/32"), + Ports: filter.PortRange{0, 0xffff}, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: 
[]filter.NetPortRange{ + { + Net: netip.MustParsePrefix("10.33.0.0/16"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" { + t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) + } +} diff --git a/integration/tailscale.go b/integration/tailscale.go index e7bf71b..7187a81 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" "tailscale.com/ipn/ipnstate" + "tailscale.com/types/netmap" ) // nolint @@ -26,6 +27,7 @@ type TailscaleClient interface { IPs() ([]netip.Addr, error) FQDN() (string, error) Status() (*ipnstate.Status, error) + Netmap() (*netmap.NetworkMap, error) WaitForNeedsLogin() error WaitForRunning() error WaitForPeers(expected int) error diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 7404f6e..c30118d 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -17,6 +17,7 @@ import ( "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "tailscale.com/ipn/ipnstate" + "tailscale.com/types/netmap" ) const ( @@ -519,6 +520,30 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) { return &status, err } +// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. +// Only works with Tailscale 1.56.1 and newer. 
+func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { + command := []string{ + "tailscale", + "debug", + "netmap", + } + + result, stderr, err := t.Execute(command) + if err != nil { + fmt.Printf("stderr: %s\n", stderr) + return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err) + } + + var nm netmap.NetworkMap + err = json.Unmarshal([]byte(result), &nm) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err) + } + + return &nm, err +} + // FQDN returns the FQDN as a string of the Tailscale instance. func (t *TailscaleInContainer) FQDN() (string, error) { if t.fqdn != "" { From a369d57a1736c370b17ac9530be73e1b589f69a9 Mon Sep 17 00:00:00 2001 From: dyz Date: Mon, 22 Jan 2024 00:38:24 +0800 Subject: [PATCH 04/13] fix node expire error due to type in gorm model Update (#1692) Fixes #1674 Signed-off-by: fortitude.zhang --- hscontrol/db/node.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 880a0e1..e2a82cc 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -900,7 +900,7 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // Do not use setNodeExpiry as that has a notifier hook, which // can cause a deadlock, we are updating all changed nodes later // and there is no point in notifiying twice. - if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ + if err := hsdb.db.Model(&nodes[index]).Updates(types.Node{ Expiry: &started, }).Error; err != nil { log.Error(). From b4210e2c9003792a86386ccef9bf09227b7118f3 Mon Sep 17 00:00:00 2001 From: danielalvsaaker <30574112+danielalvsaaker@users.noreply.github.com> Date: Thu, 25 Jan 2024 09:53:34 +0100 Subject: [PATCH 05/13] Trim client secret after reading from file (#1697) Reading from file will include a line break, which results in a mismatching client secret compared to reading directly from the config. 
--- hscontrol/types/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 8e61973..d9d5830 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -590,7 +590,7 @@ func GetHeadscaleConfig() (*Config, error) { if err != nil { return nil, err } - oidcClientSecret = string(secretBytes) + oidcClientSecret = strings.TrimSpace(string(secretBytes)) } return &Config{ From 4ea12f472a3c4a632b5a05b96d061368ca5f2604 Mon Sep 17 00:00:00 2001 From: derelm <465155+derelm@users.noreply.github.com> Date: Sat, 3 Feb 2024 15:30:15 +0100 Subject: [PATCH 06/13] Fix failover to disabled route #1706 (#1707) * fix #1706 - failover should disregard disabled routes during failover * fixe tests for failover; all current tests assume routes to be enabled * add testcase for #1706 - failover to disabled route --- hscontrol/db/routes.go | 4 +++ hscontrol/db/routes_test.go | 58 +++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index dcf00bc..8ee91d6 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -585,6 +585,10 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro continue } + if !route.Enabled { + continue + } + if hsdb.notifier.IsConnected(route.Node.MachineKey) { newPrimary = &routes[idx] break diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d491b6a..1545607 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -371,6 +371,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -382,6 +383,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -392,6 +394,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, 
IsPrimary: false, + Enabled: true, }, }, want: []key.MachinePublic{ @@ -411,6 +414,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: false, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -422,6 +426,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -432,6 +437,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: false, + Enabled: true, }, }, want: nil, @@ -448,6 +454,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -459,6 +466,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: false, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -469,6 +477,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -479,6 +488,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[2], }, IsPrimary: false, + Enabled: true, }, }, want: []key.MachinePublic{ @@ -498,6 +508,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -509,6 +520,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, // Offline types.Route{ @@ -520,6 +532,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[3], }, IsPrimary: false, + Enabled: true, }, }, want: nil, @@ -536,6 +549,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -547,6 +561,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, // Offline types.Route{ @@ -558,6 +573,7 @@ func TestFailoverRoute(t *testing.T) { 
MachineKey: machineKeys[3], }, IsPrimary: false, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -568,6 +584,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, }, want: []key.MachinePublic{ @@ -576,6 +593,47 @@ func TestFailoverRoute(t *testing.T) { }, wantErr: false, }, + { + name: "failover-primary-none-enabled", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + Enabled: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + Enabled: true, + }, + // not enabled + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: false, + Enabled: false, + }, + }, + want: nil, + wantErr: false, + }, } for _, tt := range tests { From cbf57e27a78922c88edd8de8f02c99810a653786 Mon Sep 17 00:00:00 2001 From: DeveloperDragon <42499964+TotoTheDragon@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:45:35 +0100 Subject: [PATCH 07/13] Login with OIDC after having been logged out (#1719) --- hscontrol/auth.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 4fe5a16..9b44c2d 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -199,6 +199,19 @@ func (h *Headscale) handleRegister( return } + // When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed + if node.NodeKey.String() != registerRequest.NodeKey.String() && + registerRequest.OldNodeKey.IsZero() && !node.IsExpired() { + h.handleNodeKeyRefresh( + writer, + registerRequest, + *node, + machineKey, + ) + + return + } + if registerRequest.Followup != "" { select { case <-req.Context().Done(): From 
83769ba715408c05cc5defc1562e0bfe1d368de6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 8 Feb 2024 17:28:19 +0100 Subject: [PATCH 08/13] Replace database locks with transactions (#1701) This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby --- CHANGELOG.md | 2 +- cmd/headscale/headscale.go | 14 - hscontrol/app.go | 58 ++- hscontrol/auth.go | 63 ++- hscontrol/db/addresses.go | 19 +- hscontrol/db/addresses_test.go | 46 ++- hscontrol/db/api_key.go | 33 +- hscontrol/db/db.go | 62 ++- hscontrol/db/node.go | 561 ++++++++++----------------- hscontrol/db/node_test.go | 80 ++-- hscontrol/db/preauth_keys.go | 121 +++--- hscontrol/db/preauth_keys_test.go | 20 +- hscontrol/db/routes.go | 425 ++++++++++---------- hscontrol/db/routes_test.go | 323 ++++++++++++--- hscontrol/db/suite_test.go | 2 - hscontrol/db/users.go | 104 ++--- hscontrol/db/users_test.go | 6 +- hscontrol/derp/server/derp_server.go | 2 +- hscontrol/grpcv1.go | 187 +++++++-- hscontrol/mapper/mapper.go | 19 +- hscontrol/mapper/tail.go | 2 +- hscontrol/notifier/notifier.go | 97 +++-- hscontrol/oidc.go | 33 +- hscontrol/policy/acls.go | 47 ++- hscontrol/poll.go | 75 +++- hscontrol/poll_noise.go | 2 +- hscontrol/types/common.go | 38 ++ hscontrol/types/users.go | 17 +- integration/cli_test.go | 4 +- integration/general_test.go | 126 +++--- integration/route_test.go | 32 +- integration/scenario.go | 4 +- 32 files changed, 1496 insertions(+), 1128 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5441a5..a7908ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - 
Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - - The latest supported client is 1.36 + - The latest supported client is 1.38 - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. - Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611) diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index dfaf512..3f3322e 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -6,25 +6,11 @@ import ( "github.com/efekarakus/termcolor" "github.com/juanfont/headscale/cmd/headscale/cli" - "github.com/pkg/profile" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) func main() { - if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { - if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { - err := os.MkdirAll(profilePath, os.ModePerm) - if err != nil { - log.Fatal().Err(err).Msg("failed to create profiling directory") - } - - defer profile.Start(profile.ProfilePath(profilePath)).Stop() - } else { - defer profile.Start().Stop() - } - } - var colors bool switch l := termcolor.SupportLevel(os.Stderr); l { case termcolor.Level16M: diff --git a/hscontrol/app.go b/hscontrol/app.go index 75dfdde..91d5326 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -33,6 +33,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" + "github.com/pkg/profile" "github.com/prometheus/client_golang/prometheus/promhttp" zl 
"github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -48,6 +49,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -61,7 +63,7 @@ var ( "unknown value for Lets Encrypt challenge type", ) errEmptyInitialDERPMap = errors.New( - "initial DERPMap is empty, Headscale requries at least one entry", + "initial DERPMap is empty, Headscale requires at least one entry", ) ) @@ -166,7 +168,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { cfg.DBtype, dbString, app.dbDebug, - app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -234,8 +235,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { // seen for longer than h.cfg.EphemeralNodeInactivityTimeout. func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) + + var update types.StateUpdate + var changed bool for range ticker.C { - h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) + + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring ephemeral nodes") + continue + } + + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -246,9 +262,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) { ticker := time.NewTicker(interval) lastCheck := time.Unix(0, 0) + var update types.StateUpdate + var changed bool for range ticker.C { - lastCheck = h.db.ExpireExpiredNodes(lastCheck) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) + + return nil + }); err != nil { + 
log.Error().Err(err).Msg("database error while expiring nodes") + continue + } + + log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes") + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -278,7 +309,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { DERPMap: h.DERPMap, } if stateUpdate.Valid() { - h.nodeNotifier.NotifyAll(stateUpdate) + ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) } } } @@ -485,6 +517,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches a GIN server with the Headscale API. func (h *Headscale) Serve() error { + if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { + if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { + err := os.MkdirAll(profilePath, os.ModePerm) + if err != nil { + log.Fatal().Err(err).Msg("failed to create profiling directory") + } + + defer profile.Start(profile.ProfilePath(profilePath)).Stop() + } else { + defer profile.Start().Stop() + } + } + var err error // Fetch an initial DERP Map before we start serving @@ -753,7 +798,8 @@ func (h *Headscale) Serve() error { Str("path", aclPath). 
Msg("ACL policy successfully reloaded, notifying nodes of change") - h.nodeNotifier.NotifyAll(types.StateUpdate{ + ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, }) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 9b44c2d..3e9557a 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -1,6 +1,7 @@ package hscontrol import ( + "context" "encoding/json" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -243,8 +245,6 @@ func (h *Headscale) handleRegister( // handleAuthKey contains the logic to manage auth key client registration // When using Noise, the machineKey is Zero. -// -// TODO: check if any locks are needed around IP allocation. func (h *Headscale) handleAuthKey( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, @@ -311,6 +311,9 @@ func (h *Headscale) handleAuthKey( nodeKey := registerRequest.NodeKey + var update types.StateUpdate + var mkey key.MachinePublic + // retrieve node information if it exist // The error is not important, because if it does not // exist, then this is a new node and we will move @@ -324,7 +327,7 @@ func (h *Headscale) handleAuthKey( node.NodeKey = nodeKey node.AuthKeyID = uint(pak.ID) - err := h.db.NodeSetExpiry(node, registerRequest.Expiry) + err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry) if err != nil { log.Error(). Caller(). 
@@ -335,10 +338,13 @@ func (h *Headscale) handleAuthKey( return } + mkey = node.MachineKey + update = types.StateUpdateExpire(node.ID, registerRequest.Expiry) + aclTags := pak.Proto().GetAclTags() if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(node, aclTags) + err = h.db.SetTags(node.ID, aclTags) if err != nil { log.Error(). @@ -370,6 +376,7 @@ func (h *Headscale) handleAuthKey( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, + User: pak.User, MachineKey: machineKey, RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, @@ -393,9 +400,18 @@ func (h *Headscale) handleAuthKey( return } + + mkey = node.MachineKey + update = types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from auth.handleAuthKey", + } } - err = h.db.UsePreAuthKey(pak) + err = h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.UsePreAuthKey(tx, pak) + }) if err != nil { log.Error(). Caller(). @@ -437,6 +453,13 @@ func (h *Headscale) handleAuthKey( Caller(). Err(err). Msg("Failed to write response") + return + } + + // TODO(kradalby): if notifying after register make sense. + if update.Valid() { + ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String()) } log.Info(). @@ -502,7 +525,7 @@ func (h *Headscale) handleNodeLogOut( Msg("Client requested logout") now := time.Now() - err := h.db.NodeSetExpiry(&node, now) + err := h.db.NodeSetExpiry(node.ID, now) if err != nil { log.Error(). Caller(). 
@@ -513,17 +536,10 @@ func (h *Headscale) handleNodeLogOut( return } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &now, - }, - }, - } + stateUpdate := types.StateUpdateExpire(node.ID, now) if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } resp.AuthURL = "" @@ -554,7 +570,7 @@ func (h *Headscale) handleNodeLogOut( } if node.IsEphemeral() { - err = h.db.DeleteNode(&node) + err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) if err != nil { log.Error(). Err(err). @@ -562,6 +578,15 @@ func (h *Headscale) handleNodeLogOut( Msg("Cannot delete ephemeral node from the database") } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return } @@ -633,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh( Str("node", node.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey) + err := h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) + }) if err != nil { log.Error(). Caller(). 
diff --git a/hscontrol/db/addresses.go b/hscontrol/db/addresses.go index beccf84..5857870 100644 --- a/hscontrol/db/addresses.go +++ b/hscontrol/db/addresses.go @@ -13,16 +13,23 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" + "gorm.io/gorm" ) var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) { + return getAvailableIPs(rx, hsdb.ipPrefixes) + }) +} + +func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) { var ips types.NodeAddresses var err error - for _, ipPrefix := range hsdb.ipPrefixes { + for _, ipPrefix := range ipPrefixes { var ip *netip.Addr - ip, err = hsdb.getAvailableIP(ipPrefix) + ip, err = getAvailableIP(rx, ipPrefix) if err != nil { return ips, err } @@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { return ips, err } -func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := hsdb.getUsedIPs() +func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) { + usedIps, err := getUsedIPs(rx) if err != nil { return nil, err } @@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro } } -func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { +func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. 
var addressesSlices []string - hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) + rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) var ips netipx.IPSetBuilder for _, slice := range addressesSlices { diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 07059ea..ef33659 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -7,10 +7,16 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := db.getAvailableIPs() + tx := db.DB.Begin() + defer tx.Rollback() + + ips, err := getAvailableIPs(tx, []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }) c.Assert(err, check.IsNil) @@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - usedIps, err := db.getUsedIPs() + db.Write(func(tx *gorm.DB) error { + return tx.Save(&node).Error + }) + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) { } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := db.CreateUser("test-ip-multi") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) + ipPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + for index := 1; index <= 350; index++ { - db.ipAllocationMutex.Lock() + tx := db.DB.Begin() - ips, err := db.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) c.Assert(err, 
check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = getNode(tx, "test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - db.ipAllocationMutex.Unlock() + tx.Save(&node) + c.Assert(tx.Commit().Error, check.IsNil) } - usedIps, err := db.getUsedIPs() + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) ips2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index bc8dc2b..5108314 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey( Expiration: expiration, } - if err := hsdb.db.Save(&key).Error; err != nil { + if err := hsdb.DB.Save(&key).Error; err != nil { return "", nil, 
fmt.Errorf("failed to save API key to database: %w", err) } @@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - keys := []types.APIKey{} - if err := hsdb.db.Find(&keys).Error; err != nil { + if err := hsdb.DB.Find(&keys).Error; err != nil { return nil, err } @@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { + if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. 
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 030a6f0..df7b0a4 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -7,12 +7,10 @@ import ( "fmt" "net/netip" "strings" - "sync" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" - "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -36,12 +34,7 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifier *notifier.Notifier - - mu sync.RWMutex - - ipAllocationMutex sync.Mutex + DB *gorm.DB ipPrefixes []netip.Prefix baseDomain string @@ -52,7 +45,6 @@ type HSDatabase struct { func NewHeadscaleDatabase( dbType, connectionAddr string, debug bool, - notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { @@ -147,7 +139,9 @@ func NewHeadscaleDatabase( DiscoKey string } var results []result - err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error + err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes"). + Find(&results). + Error if err != nil { return err } @@ -180,7 +174,8 @@ func NewHeadscaleDatabase( } if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { - log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...") + log.Info(). 
+ Msgf("Database has legacy enabled_routes column in node, migrating...") type NodeAux struct { ID uint64 @@ -317,8 +312,7 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - db: dbConn, - notifier: notifier, + DB: dbConn, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -376,7 +370,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - sqlDB, err := hsdb.db.DB() + sqlDB, err := hsdb.DB.DB() if err != nil { return err } @@ -385,10 +379,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error { } func (hsdb *HSDatabase) Close() error { - db, err := hsdb.db.DB() + db, err := hsdb.DB.DB() if err != nil { return err } return db.Close() } + +func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error { + rx := hsdb.DB.Begin() + defer rx.Rollback() + return fn(rx) +} + +func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { + rx := db.Begin() + defer rx.Rollback() + ret, err := fn(rx) + if err != nil { + var no T + return no, err + } + return ret, nil +} + +func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { + tx := hsdb.DB.Begin() + defer tx.Rollback() + if err := fn(tx); err != nil { + return err + } + + return tx.Commit().Error +} + +func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) { + tx := db.Begin() + defer tx.Rollback() + ret, err := fn(tx) + if err != nil { + var no T + return no, err + } + return ret, tx.Commit().Error +} diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e2a82cc..a747429 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -34,22 +34,21 @@ var ( ) ) -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. 
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPeers(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListPeers(rx, node) + }) } -func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { +// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. +func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) { log.Trace(). Caller(). Str("node", node.Hostname). Msg("Finding direct peers") nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodes() +func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) } -func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { - nodes := []types.Node{} - if err := hsdb.db. +func ListNodes(tx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByGivenName(givenName) -} - -func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) { +func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). 
@@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err return nodes, nil } -// GetNode finds a Node by name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return getNode(rx, user, name) + }) +} - nodes, err := hsdb.ListNodesByUser(user) +// getNode finds a Node by name and user and returns the Node struct. +func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, user) if err != nil { return nil, err } @@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { return nil, ErrNodeNotFound } -// GetNodeByGivenName finds a Node by given name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByGivenName( - user string, - givenName string, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if err := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).First(&node).Error; err != nil { - return nil, err - } - - return nil, ErrNodeNotFound +func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByID(rx, id) + }) } // GetNodeByID finds a Node by ID and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). 
@@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { return &mach, nil } -// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByMachineKey( - machineKey key.MachinePublic, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeByMachineKey(machineKey) +func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByMachineKey(rx, machineKey) + }) } -func (hsdb *HSDatabase) getNodeByMachineKey( +// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. +func GetNodeByMachineKey( + tx *gorm.DB, machineKey key.MachinePublic, ) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey( return &mach, nil } -// GetNodeByNodeKey finds a Node by its current NodeKey. -func (hsdb *HSDatabase) GetNodeByNodeKey( +func (hsdb *HSDatabase) GetNodeByAnyKey( + machineKey key.MachinePublic, nodeKey key.NodePublic, + oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if result := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "node_key = ?", - nodeKey.String()); result.Error != nil { - return nil, result.Error - } - - return &node, nil + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) + }) } // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByAnyKey( +// TODO(kradalby): see if we can remove this. 
+func GetNodeByAnyKey( + tx *gorm.DB, machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - node := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey( return &node, nil } -func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - if result := hsdb.db.Find(node).First(&node); result.Error != nil { - return result.Error - } - - return nil +func (hsdb *HSDatabase) SetTags( + nodeID uint64, + tags []string, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return SetTags(tx, nodeID, tags) + }) } // SetTags takes a Node struct pointer and update the forced tags. -func (hsdb *HSDatabase) SetTags( - node *types.Node, +func SetTags( + tx *gorm.DB, + nodeID uint64, tags []string, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - if len(tags) == 0 { return nil } - newTags := []string{} + newTags := types.StringList{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } - if err := hsdb.db.Model(node).Updates(types.Node{ - ForcedTags: newTags, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { return fmt.Errorf("failed to update tags for node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.SetTags", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } // RenameNode takes a Node struct and a new GivenName for the nodes // and renames it. 
-func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameNode(tx *gorm.DB, + nodeID uint64, newName string, +) error { err := util.CheckForFQDNRules( newName, ) @@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { log.Error(). Caller(). Str("func", "RenameNode"). - Str("node", node.Hostname). + Uint64("nodeID", nodeID). Str("newName", newName). Err(err). Msg("failed to rename node") return err } - node.GivenName = newName - if err := hsdb.db.Model(node).Updates(types.Node{ - GivenName: newName, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.RenameNode", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } +func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetExpiry(tx, nodeID, expiry) + }) +} + // NodeSetExpiry takes a Node struct and a new expiry time. 
-func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.nodeSetExpiry(node, expiry) +func NodeSetExpiry(tx *gorm.DB, + nodeID uint64, expiry time.Time, +) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } -func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error { - if err := hsdb.db.Model(node).Updates(types.Node{ - Expiry: &expiry, - }).Error; err != nil { - return fmt.Errorf( - "failed to refresh node (update expiration) in the database: %w", - err, - ) - } - - node.Expiry = &expiry - - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &expiry, - }, - }, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - - return nil +func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DeleteNode(tx, node, isConnected) + }) } // DeleteNode deletes a Node from the database. -func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.deleteNode(node) -} - -func (hsdb *HSDatabase) deleteNode(node *types.Node) error { - err := hsdb.deleteNodeRoutes(node) +// Caller is responsible for notifying all of change. +func DeleteNode(tx *gorm.DB, + node *types.Node, + isConnected map[key.MachinePublic]bool, +) error { + err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{}) if err != nil { return err } // Unscoped causes the node to be fully removed from the database. 
- if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil { + if err := tx.Unscoped().Delete(&node).Error; err != nil { return err } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - return nil } // UpdateLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. -// This is mostly used to indicate if a node is online and is not -// extremely important to make sure is fully correct and to avoid -// holding up the hot path, does not contain any locks and isnt -// concurrency safe. But that should be ok. -func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error { - return hsdb.db.Model(node).Updates(types.Node{ - LastSeen: node.LastSeen, - }).Error +func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } -func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( +func RegisterNodeFromAuthCallback( + tx *gorm.DB, cache *cache.Cache, mkey key.MachinePublic, userName string, nodeExpiry *time.Time, registrationMethod string, + ipPrefixes []netip.Prefix, ) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - log.Debug(). Str("machine_key", mkey.ShortString()). Str("userName", userName). 
@@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( if nodeInterface, ok := cache.Get(mkey.String()); ok { if registrationNode, ok := nodeInterface.(types.Node); ok { - user, err := hsdb.getUser(userName) + user, err := GetUser(tx, userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register node from auth callback, %w", @@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( } registrationNode.UserID = user.ID + registrationNode.User = *user registrationNode.RegisterMethod = registrationMethod if nodeExpiry != nil { registrationNode.Expiry = nodeExpiry } - node, err := hsdb.registerNode( + node, err := RegisterNode( + tx, registrationNode, + ipPrefixes, ) if err == nil { @@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( return nil, ErrNodeNotFoundRegistrationCache } -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.registerNode(node) + return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { + return RegisterNode(tx, node, hsdb.ipPrefixes) + }) } -func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { +// RegisterNode is executed from the CLI to register a new Node using its MachineKey. +func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) { log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). 
@@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache if len(node.IPAddresses) > 0 { - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register existing node in the database: %w", err) } @@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { return &node, nil } - hsdb.ipAllocationMutex.Lock() - defer hsdb.ipAllocationMutex.Unlock() - - ips, err := hsdb.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) if err != nil { log.Error(). Caller(). @@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { node.IPAddresses = ips - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } @@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { } // NodeSetNodeKey sets the node key of a node and saves it to the database. -func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(node).Updates(types.Node{ +func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { + return tx.Model(node).Updates(types.Node{ NodeKey: nodeKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } -// NodeSetMachineKey sets the node key of a node and saves it to the database. 
func (hsdb *HSDatabase) NodeSetMachineKey( node *types.Node, machineKey key.MachinePublic, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetMachineKey(tx, node, machineKey) + }) +} - if err := hsdb.db.Model(node).Updates(types.Node{ +// NodeSetMachineKey sets the node key of a node and saves it to the database. +func NodeSetMachineKey( + tx *gorm.DB, + node *types.Node, + machineKey key.MachinePublic, +) error { + return tx.Model(node).Updates(types.Node{ MachineKey: machineKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } // NodeSave saves a node object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. -func (hsdb *HSDatabase) NodeSave(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +// TODO(kradalby): Remove this func, just use Save. +func NodeSave(tx *gorm.DB, node *types.Node) error { + return tx.Save(node).Error +} - if err := hsdb.db.Save(node).Error; err != nil { - return err - } - - return nil +func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetAdvertisedRoutes(rx, node) + }) } // GetAdvertisedRoutes returns the routes that are be advertised by the given node. -func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { +func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? 
AND advertised = ?", node.ID, true).Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e return prefixes, nil } -// GetEnabledRoutes returns the routes that are enabled for the node. func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getEnabledRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetEnabledRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { +// GetEnabledRoutes returns the routes that are enabled for the node. +func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). Find(&routes).Error @@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro return prefixes, nil } -func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := hsdb.getEnabledRoutes(node) + enabledRoutes, err := GetEnabledRoutes(tx, node) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool return false } +func (hsdb *HSDatabase) enableRoutes( + node *types.Node, + routeStrs ...string, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return enableRoutes(tx, node, routeStrs...) 
+ }) +} + // enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { +func enableRoutes(tx *gorm.DB, + node *types.Node, routeStrs ...string, +) (*types.StateUpdate, error) { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) if err != nil { - return err + return nil, err } newRoutes[index] = route } - advertisedRoutes, err := hsdb.getAdvertisedRoutes(node) + advertisedRoutes, err := GetAdvertisedRoutes(tx, node) if err != nil { - return err + return nil, err } for _, newRoute := range newRoutes { if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { - return fmt.Errorf( + return nil, fmt.Errorf( "route (%s) is not available on node %s: %w", node.Hostname, newRoute, ErrNodeRouteIsNotAvailable, @@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { route := types.Route{} - err := hsdb.db.Preload("Node"). + err := tx.Preload("Node"). Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). 
First(&route).Error if err == nil { @@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) if !route.IsExitRoute() { - route.IsPrimary = hsdb.isUniquePrefix(route) + route.IsPrimary = isUniquePrefix(tx, route) } - err = hsdb.db.Save(&route).Error + err = tx.Save(&route).Error if err != nil { - return fmt.Errorf("failed to enable route: %w", err) + return nil, fmt.Errorf("failed to enable route: %w", err) } } else { - return fmt.Errorf("failed to find route: %w", err) + return nil, fmt.Errorf("failed to find route: %w", err) } } // Ensure the node has the latest routes when notifying the other // nodes - nRoutes, err := hsdb.getNodeRoutes(node) + nRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return fmt.Errorf("failed to read back routes: %w", err) + return nil, fmt.Errorf("failed to read back routes: %w", err) } node.Routes = nRoutes @@ -729,30 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro Strs("routes", routeStrs). Msg("enabling routes") - stateUpdate := types.StateUpdate{ + return &types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, - Message: "called from db.enableRoutes", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore( - stateUpdate, node.MachineKey.String()) - } - - // Send an update to the node itself with to ensure it - // has an updated packetfilter allowing the new route - // if it is defined in the ACL. 
- selfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if selfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey( - selfUpdate, - node.MachineKey) - } - - return nil + Message: "created in db.enableRoutes", + }, nil } func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { @@ -785,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName( mkey key.MachinePublic, suppliedName string, ) (string, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() + return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { + return GenerateGivenName(rx, mkey, suppliedName) + }) +} +func GenerateGivenName( + tx *gorm.DB, + mkey key.MachinePublic, + suppliedName string, +) (string, error) { givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := hsdb.listNodesByGivenName(givenName) + nodes, err := listNodesByGivenName(tx, givenName) if err != nil { return "", err } @@ -818,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName( return givenName, nil } -func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - users, err := hsdb.listUsers() +func ExpireEphemeralNodes(tx *gorm.DB, + inactivityThreshhold time.Duration, +) (types.StateUpdate, bool) { + users, err := ListUsers(tx) if err != nil { log.Error().Err(err).Msg("Error listing users") - return + return types.StateUpdate{}, false } + expired := make([]tailcfg.NodeID, 0) for _, user := range users { - nodes, err := hsdb.listNodesByUser(user.Name) + nodes, err := ListNodesByUser(tx, user.Name) if err != nil { log.Error(). Err(err). Str("user", user.Name). Msg("Error listing nodes in user") - return + return types.StateUpdate{}, false } - expired := make([]tailcfg.NodeID, 0) for idx, node := range nodes { if node.IsEphemeral() && node.LastSeen != nil && time.Now(). 
@@ -851,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) Str("node", node.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.deleteNode(nodes[idx]) + // empty isConnected map as ephemeral nodes are not routes + err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{}) if err != nil { log.Error(). Err(err). @@ -861,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) } } - if len(expired) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: expired, - }) - } + // TODO(kradalby): needs to be moved out of transaction } + if len(expired) > 0 { + return types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }, true + } + + return types.StateUpdate{}, false } -func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func ExpireExpiredNodes(tx *gorm.DB, + lastCheck time.Time, +) (time.Time, types.StateUpdate, bool) { // use the time of the start of the function to ensure we // dont miss some nodes by returning it _after_ we have // checked everything. started := time.Now() - expiredNodes := make([]*types.Node, 0) + expired := make([]*tailcfg.PeerChange, 0) - nodes, err := hsdb.listNodes() + nodes, err := ListNodes(tx) if err != nil { log.Error(). Err(err). Msg("Error listing nodes to find expired nodes") - return time.Unix(0, 0) + return time.Unix(0, 0), types.StateUpdate{}, false } for index, node := range nodes { if node.IsExpired() && @@ -895,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // It will notify about all nodes that has been expired. // It should only notify about expired nodes since _last check_. 
node.Expiry.After(lastCheck) { - expiredNodes = append(expiredNodes, &nodes[index]) + expired = append(expired, &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: node.Expiry, + }) + now := time.Now() // Do not use setNodeExpiry as that has a notifier hook, which // can cause a deadlock, we are updating all changed nodes later // and there is no point in notifiying twice. - if err := hsdb.db.Model(&nodes[index]).Updates(types.Node{ - Expiry: &started, + if err := tx.Model(&nodes[index]).Updates(types.Node{ + Expiry: &now, }).Error; err != nil { log.Error(). Err(err). @@ -917,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { } } - expired := make([]*tailcfg.PeerChange, len(expiredNodes)) - for idx, node := range expiredNodes { - expired[idx] = &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &started, - } + if len(expired) > 0 { + return started, types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: expired, + }, true } - // Inform the peers of a node with a lightweight update. - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - // Inform the node itself that it has expired. 
- for _, node := range expiredNodes { - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - } - - return started + return started, types.StateUpdate{}, false } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 140c264..5e8eb29 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) } @@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByID(0) c.Assert(err, check.IsNil) } -func (s *Suite) TestGetNodeByNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - func (s *Suite) TestGetNodeByAnyNodeKey(c 
*check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) @@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(1), } - db.db.Save(&node) + db.DB.Save(&node) - err = db.DeleteNode(&node) + err = db.DeleteNode(&node, map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) - _, err = db.GetNode(user.Name, "testnode3") + _, err = db.getNode(user.Name, "testnode3") c.Assert(err, check.NotNil) } @@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) } node0ByID, err := db.GetNodeByID(0) @@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } - db.db.Save(&node) + db.DB.Save(&node) } aclPolicy := &policy.ACLPolicy{ @@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) { AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, } - db.db.Save(node) + db.DB.Save(node) - nodeFromDB, err := db.GetNode("test", "testnode") + nodeFromDB, err := db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, false) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB, now) + err = 
db.NodeSetExpiry(nodeFromDB.ID, now) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("user-1", "testnode") + _, err = db.getNode("user-1", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") @@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = db.SetTags(node, sTags) + err = db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) // assign duplicat tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = db.SetTags(node, eTags) + err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) 
c.Assert( node.ForcedTags, @@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { node0ByID, err := db.GetNodeByID(0) c.Assert(err, check.IsNil) - err = db.EnableAutoApprovedRoutes(pol, node0ByID) + // TODO(kradalby): Check state update + _, err = db.EnableAutoApprovedRoutes(pol, node0ByID) c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index e743988..0fdb822 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -20,7 +20,6 @@ var ( ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) -// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( userName string, reusable bool, @@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { + return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) + }) +} - user, err := hsdb.GetUser(userName) +// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. 
+func CreatePreAuthKey( + tx *gorm.DB, + userName string, + reusable bool, + ephemeral bool, + expiration *time.Time, + aclTags []string, +) (*types.PreAuthKey, error) { + user, err := GetUser(tx, userName) if err != nil { return nil, err } @@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( } now := time.Now().UTC() - kstr, err := hsdb.generateKey() + kstr, err := generateKey() if err != nil { return nil, err } @@ -63,29 +72,25 @@ func (hsdb *HSDatabase) CreatePreAuthKey( Expiration: expiration, } - err = hsdb.db.Transaction(func(db *gorm.DB) error { - if err := db.Save(&key).Error; err != nil { - return fmt.Errorf("failed to create key in the database: %w", err) - } + if err := tx.Save(&key).Error; err != nil { + return nil, fmt.Errorf("failed to create key in the database: %w", err) + } - if len(aclTags) > 0 { - seenTags := map[string]bool{} + if len(aclTags) > 0 { + seenTags := map[string]bool{} - for _, tag := range aclTags { - if !seenTags[tag] { - if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { - return fmt.Errorf( - "failed to ceate key tag in the database: %w", - err, - ) - } - seenTags[tag] = true + for _, tag := range aclTags { + if !seenTags[tag] { + if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + return nil, fmt.Errorf( + "failed to ceate key tag in the database: %w", + err, + ) } + seenTags[tag] = true } } - - return nil - }) + } if err != nil { return nil, err @@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( return &key, nil } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. 
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPreAuthKeys(userName) + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { + return ListPreAuthKeys(rx, userName) + }) } -func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - user, err := hsdb.getUser(userName) +// ListPreAuthKeys returns the list of PreAuthKeys for a user. +func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { + user, err := GetUser(tx, userName) if err != nil { return nil, err } keys := []types.PreAuthKey{} - if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er } // GetPreAuthKey returns a PreAuthKey for a given key. -func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - pak, err := hsdb.ValidatePreAuthKey(key) +func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) { + pak, err := ValidatePreAuthKey(tx, key) if err != nil { return nil, err } @@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. 
-func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.destroyPreAuthKey(pak) -} - -func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { - return hsdb.db.Transaction(func(db *gorm.DB) error { +func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { + return tx.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return ExpirePreAuthKey(tx, k) + }) +} - if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. 
-func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { k.Used = true - if err := hsdb.db.Save(k).Error; err != nil { + if err := tx.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } return nil } +func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { + return ValidatePreAuthKey(rx, k) + }) +} + // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { pak := types.PreAuthKey{} - if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Where(&types.Node{AuthKeyID: uint(pak.ID)}). 
Find(&nodes).Error; err != nil { @@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) return &pak, nil } -func (hsdb *HSDatabase) generateKey() (string, error) { +func generateKey() (string, error) { size := 24 bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index df9c2a1..003a396 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,6 +6,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { user, err := db.CreateUser("test2") c.Assert(err, check.IsNil) - now := time.Now() + now := time.Now().Add(-5 * time.Second) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) @@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) @@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.IsNil) - db.ExpireEphemeralNodes(time.Second * 20) + db.DB.Transaction(func(tx 
*gorm.DB) error { + ExpireEphemeralNodes(tx, time.Second*20) + return nil + }) // The machine record should have been deleted - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.NotNil) } @@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - db.db.Save(&pak) + db.DB.Save(&pak) _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 8ee91d6..1ee144a 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -7,23 +7,15 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" - "github.com/samber/lo" "gorm.io/gorm" "tailscale.com/types/key" ) var ErrRouteIsNotAvailable = errors.New("route is not available") -func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoutes() -} - -func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { +func GetRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Find(&routes).Error @@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { +func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("advertised = ? AND enabled = ?", true, true). 
@@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { +func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("prefix = ?", types.IPPrefix(pref)). @@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro return routes, nil } -func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { +func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ? AND advertised = true", node.ID). @@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, } func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodeRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { +func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ?", node.ID). 
@@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoute(id) -} - -func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { +func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). First(&route, id).Error @@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { return &route, nil } -func (hsdb *HSDatabase) EnableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.enableRoute(id) -} - -func (hsdb *HSDatabase) enableRoute(id uint64) error { - route, err := hsdb.getRoute(id) +func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if route.IsExitRoute() { - return hsdb.enableRoutes( + return enableRoutes( + tx, &route.Node, types.ExitRouteV4.String(), types.ExitRouteV6.String(), ) } - return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String()) + return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String()) } -func (hsdb *HSDatabase) DisableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - route, err := hsdb.getRoute(id) +func DisableRoute(tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at 
the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err = hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return err + return nil, err } route.Enabled = false route.IsPrimary = false - err = hsdb.db.Save(route).Error + err = tx.Save(route).Error if err != nil { - return err + return nil, err } } else { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } for i := range routes { if routes[i].IsExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false - err = hsdb.db.Save(&routes[i]).Error + err = tx.Save(&routes[i]).Error if err != nil { - return err + return nil, err } } } } if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DisableRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create + // one and return to the caller. 
+ if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DisableRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +func (hsdb *HSDatabase) DeleteRoute( + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return DeleteRoute(tx, id, isConnected) + }) +} - route, err := hsdb.getRoute(id) +func DeleteRoute( + tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err := hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return nil + return nil, nil } - if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&route).Error; err != nil { + return nil, err } } else { - routes, err := hsdb.getNodeRoutes(&node) + routes, err := GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } routesToDelete := types.Routes{} @@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { } } - if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil { + return nil, err } } + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create 
+ // one and return to the caller. if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DeleteRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DeleteRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) +func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error { + routes, err := GetNodeRoutes(tx, node) if err != nil { return err } for i := range routes { - if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { + if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { return err } // TODO(kradalby): This is a bit too aggressive, we could probably // figure out which routes needs to be failed over rather than all. - hsdb.failoverRouteWithNotify(&routes[i]) + failoverRouteReturnUpdate(tx, isConnected, &routes[i]) } return nil } // isUniquePrefix returns if there is another node providing the same route already. -func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { +func isUniquePrefix(tx *gorm.DB, route types.Route) bool { var count int64 - hsdb.db. - Model(&types.Route{}). + tx.Model(&types.Route{}). Where("prefix = ? AND node_id != ? AND advertised = ? 
AND enabled = ?", route.Prefix, route.NodeID, @@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { return count == 0 } -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { +func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). First(&route).Error @@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro return &route, nil } +func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodePrimaryRoutes(rx, node) + }) +} + // getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true). Find(&routes).Error @@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er return routes, nil } +func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) { + return SaveNodeRoutes(tx, node) + }) +} + // SaveNodeRoutes takes a node and updates the database with // the new routes. // It returns a bool whether an update should be sent as the // saved route impacts nodes. 
-func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.saveNodeRoutes(node) -} - -func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { +func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { sendUpdate := false currentRoutes := types.Routes{} - err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error + err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error if err != nil { return sendUpdate, err } @@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { currentRoutes[pos].Advertised = true - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { } else if route.Advertised { currentRoutes[pos].Advertised = false currentRoutes[pos].Enabled = false - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { Advertised: true, Enabled: false, } - err := hsdb.db.Create(&route).Error + err := tx.Create(&route).Error if err != nil { return sendUpdate, err } @@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { // EnsureFailoverRouteIsAvailable takes a node and checks if the node's route // currently have a functioning host that exposes the network. 
-func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { - nodeRoutes, err := hsdb.getNodeRoutes(node) +func EnsureFailoverRouteIsAvailable( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + node *types.Node, +) (*types.StateUpdate, error) { + nodeRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return nil + return nil, nil } + var changedNodes types.Nodes for _, nodeRoute := range nodeRoutes { - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) if err != nil { - return err + return nil, err } for _, route := range routes { if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if isConnected[route.Node.MachineKey] { continue } // if not, we need to failover the route - err := hsdb.failoverRouteWithNotify(&route) + update, err := failoverRouteReturnUpdate(tx, isConnected, &route) if err != nil { - return err + return nil, err + } + + if update != nil { + changedNodes = append(changedNodes, update.ChangeNodes...) } } } } - return nil -} - -func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) - if err != nil { - return nil - } - - var changedKeys []key.MachinePublic - - for _, route := range routes { - changed, err := hsdb.failoverRoute(&route) - if err != nil { - return err - } - - changedKeys = append(changedKeys, changed...) 
- } - - changedKeys = lo.Uniq(changedKeys) - - var nodes types.Nodes - - for _, key := range changedKeys { - node, err := hsdb.GetNodeByMachineKey(key) - if err != nil { - return err - } - - nodes = append(nodes, node) - } - - if nodes != nil { - stateUpdate := types.StateUpdate{ + if len(changedNodes) != 0 { + return &types.StateUpdate{ Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.FailoverNodeRoutesWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } + ChangeNodes: changedNodes, + Message: "called from db.EnsureFailoverRouteIsAvailable", + }, nil } - return nil + return nil, nil } -func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { - changedKeys, err := hsdb.failoverRoute(r) +func failoverRouteReturnUpdate( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) (*types.StateUpdate, error) { + changedKeys, err := failoverRoute(tx, isConnected, r) if err != nil { - return err + return nil, err } + log.Trace(). + Interface("isConnected", isConnected). + Interface("changedKeys", changedKeys). + Msg("building route failover") + if len(changedKeys) == 0 { - return nil + return nil, nil } var nodes types.Nodes - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("loading machines with new primary routes from db") - for _, key := range changedKeys { - node, err := hsdb.getNodeByMachineKey(key) + node, err := GetNodeByMachineKey(tx, key) if err != nil { - return err + return nil, err } nodes = append(nodes, node) } - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notifying peers about primary route change") - - if nodes != nil { - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.failoverRouteWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - } - - log.Trace(). - Str("hostname", r.Node.Hostname). 
- Msg("notified peers about primary route change") - - return nil + return &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: nodes, + Message: "called from db.failoverRouteReturnUpdate", + }, nil } // failoverRoute takes a route that is no longer available, @@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { // // and tries to find a new route to take over its place. // If the given route was not primary, it returns early. -func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { +func failoverRoute( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) ([]key.MachinePublic, error) { if r == nil { return nil, nil } - // This route is not a primary route, and it isnt + // This route is not a primary route, and it is not // being served to nodes. if !r.IsPrimary { return nil, nil @@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return nil, nil } - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix)) if err != nil { return nil, err } @@ -589,14 +547,14 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro continue } - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if isConnected[route.Node.MachineKey] { newPrimary = &routes[idx] break } } // If a new route was not found/available, - // return with an error. + // return without an error. // We do not want to update the database as // the one currently marked as primary is the // best we got. 
@@ -610,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Remove primary from the old route r.IsPrimary = false - err = hsdb.db.Save(&r).Error + err = tx.Save(&r).Error if err != nil { log.Error().Err(err).Msg("error disabling new primary route") @@ -623,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Set primary for the new primary newPrimary.IsPrimary = true - err = hsdb.db.Save(&newPrimary).Error + err = tx.Save(&newPrimary).Error if err != nil { log.Error().Err(err).Msg("error enabling new primary route") @@ -638,25 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil } -// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, node *types.Node, -) error { - if len(aclPolicy.AutoApprovers.ExitNode) == 0 && len(aclPolicy.AutoApprovers.Routes) == 0 { - // No autoapprovers configured - return nil - } +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return EnableAutoApprovedRoutes(tx, aclPolicy, node) + }) +} +// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. 
+func EnableAutoApprovedRoutes( + tx *gorm.DB, + aclPolicy *policy.ACLPolicy, + node *types.Node, +) (*types.StateUpdate, error) { if len(node.IPAddresses) == 0 { - // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs - return nil + return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs } - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - routes, err := hsdb.getNodeAdvertisedRoutes(node) + routes, err := GetNodeAdvertisedRoutes(tx, node) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -664,7 +623,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("node", node.Hostname). Msg("Could not get advertised routes for node") - return err + return nil, err } log.Trace().Interface("routes", routes).Msg("routes for autoapproving") @@ -685,7 +644,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Uint64("nodeId", node.ID). Msg("Failed to resolve autoApprovers for advertised route") - return err + return nil, err } log.Trace(). @@ -706,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("alias", approvedAlias). Msg("Failed to expand alias when processing autoApprovers policy") - return err + return nil, err } // approvedIPs should contain all of node's IPs if it matches the rule, so check for first @@ -717,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } } + update := &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{}, + Message: "created in db.EnableAutoApprovedRoutes", + } + for _, approvedRoute := range approvedRoutes { - err := hsdb.enableRoute(uint64(approvedRoute.ID)) + perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). Uint64("nodeId", node.ID). Msg("Failed to enable approved route") - return err + return nil, err } + + update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...) 
} - return nil + return update, nil } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 1545607..3b544aa 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -24,7 +23,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_get_route_node") + _, err = db.getNode("test", "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -42,7 +41,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) su, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -52,10 +51,11 @@ func (s *Suite) TestGetRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = db.enableRoutes(&node, "192.168.0.0/24") + // TODO(kradalby): check state update + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) } @@ -66,7 +66,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -91,7 +91,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) 
@@ -106,10 +106,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = db.enableRoutes(&node, "192.168.0.0/24") + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(&node) @@ -117,14 +117,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = db.enableRoutes(&node, "150.0.10.0/25") + _, err = db.enableRoutes(&node, "150.0.10.0/25") c.Assert(err, check.IsNil) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node) @@ -139,7 +139,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -163,16 +163,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo1, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, route.String()) + _, err = db.enableRoutes(&node1, route.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, route2.String()) + _, err = db.enableRoutes(&node1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ @@ -186,13 +186,13 
@@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo2, } - db.db.Save(&node2) + db.DB.Save(&node2) sendUpdate, err = db.SaveNodeRoutes(&node2) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node2, route2.String()) + _, err = db.enableRoutes(&node2, route2.String()) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -219,7 +219,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -246,22 +246,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { Hostinfo: &hostInfo1, LastSeen: &now, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, prefix.String()) + _, err = db.enableRoutes(&node1, prefix.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, prefix2.String()) + _, err = db.enableRoutes(&node1, prefix2.String()) c.Assert(err, check.IsNil) routes, err := db.GetNodeRoutes(&node1) c.Assert(err, check.IsNil) - err = db.DeleteRoute(uint64(routes[0].ID)) + // TODO(kradalby): check stateupdate + _, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -269,17 +270,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { c.Assert(len(enabledRoutes1), check.Equals, 1) } +var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } + func TestFailoverRoute(t *testing.T) { - ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Count/verify updates - var sink 
chan types.StateUpdate - - go func() { - for range sink { - } - }() - machineKeys := []key.MachinePublic{ key.NewMachine().Public(), key.NewMachine().Public(), @@ -291,6 +284,7 @@ func TestFailoverRoute(t *testing.T) { name string failingRoute types.Route routes types.Routes + isConnected map[key.MachinePublic]bool want []key.MachinePublic wantErr bool }{ @@ -397,6 +391,10 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], @@ -491,6 +489,11 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[1]: true, + machineKeys[2]: true, + }, want: []key.MachinePublic{ machineKeys[1], machineKeys[0], @@ -535,6 +538,10 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[3]: false, + }, want: nil, wantErr: false, }, @@ -587,6 +594,11 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + machineKeys[3]: false, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], @@ -641,13 +653,10 @@ func TestFailoverRoute(t *testing.T) { tmpDir, err := os.MkdirTemp("", "failover-db-test") assert.NoError(t, err) - notif := notifier.NewNotifier() - db, err = NewHeadscaleDatabase( "sqlite3", tmpDir+"/headscale_test.db", false, - notif, []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, @@ -655,23 +664,15 @@ func TestFailoverRoute(t *testing.T) { ) assert.NoError(t, err) - // Pretend that all the nodes are connected to control - for idx, key := range machineKeys { - // Pretend one node is offline - if idx == 3 { - continue - } - - notif.AddNode(key, sink) - } - for _, route := range tt.routes { - if err := db.db.Save(&route).Error; err 
!= nil { + if err := db.DB.Save(&route).Error; err != nil { t.Fatalf("failed to create route: %s", err) } } - got, err := db.failoverRoute(&tt.failingRoute) + got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) { + return failoverRoute(tx, tt.isConnected, &tt.failingRoute) + }) if (err != nil) != tt.wantErr { t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) @@ -685,3 +686,231 @@ func TestFailoverRoute(t *testing.T) { }) } } + +// func TestDisableRouteFailover(t *testing.T) { +// machineKeys := []key.MachinePublic{ +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// } + +// tests := []struct { +// name string +// nodes types.Nodes + +// routeID uint64 +// isConnected map[key.MachinePublic]bool + +// wantMachineKey key.MachinePublic +// wantErr string +// }{ +// { +// name: "single-route", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// Node: types.Node{ +// MachineKey: machineKeys[0], +// }, +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[0], +// }, +// { +// name: "failover-simple", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: 
&tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "no-failover-offline", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: false, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "failover-to-online", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: true, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// } + +// for _, tt := range 
tests { +// t.Run(tt.name, func(t *testing.T) { +// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "") +// assert.NoError(t, err) + +// // bootstrap db +// datab.DB.Transaction(func(tx *gorm.DB) error { +// for _, node := range tt.nodes { +// err := tx.Save(node).Error +// if err != nil { +// return err +// } + +// _, err = SaveNodeRoutes(tx, node) +// if err != nil { +// return err +// } +// } + +// return nil +// }) + +// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { +// return DisableRoute(tx, tt.routeID, tt.isConnected) +// }) + +// // if (err.Error() != "") != tt.wantErr { +// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) + +// // return +// // } + +// if len(got.ChangeNodes) != 1 { +// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes)) +// } + +// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" { +// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff) +// } +// }) +// } +// } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c38491..d4b11b1 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,7 +6,6 @@ import ( "os" "testing" - "github.com/juanfont/headscale/hscontrol/notifier" "gopkg.in/check.v1" ) @@ -48,7 +47,6 @@ func (s *Suite) ResetDB(c *check.C) { "sqlite3", tmpDir+"/headscale_test.db", false, - notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 27a1406..99e9339 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,22 +15,25 @@ var ( ErrUserStillHasNodes = errors.New("user not empty: node(s) found") ) +func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) { + return CreateUser(tx, name) + }) +} + // CreateUser creates a new User. 
Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func CreateUser(tx *gorm.DB, name string) (*types.User, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } user := types.User{} - if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { + if err := tx.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } user.Name = name - if err := hsdb.db.Create(&user).Error; err != nil { + if err := tx.Create(&user).Error; err != nil { log.Error(). Str("func", "CreateUser"). Err(err). @@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { return &user, nil } +func (hsdb *HSDatabase) DestroyUser(name string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DestroyUser(tx, name) + }) +} + // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. 
-func (hsdb *HSDatabase) DestroyUser(name string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - user, err := hsdb.getUser(name) +func DestroyUser(tx *gorm.DB, name string) error { + user, err := GetUser(tx, name) if err != nil { return ErrUserNotFound } - nodes, err := hsdb.listNodesByUser(name) + nodes, err := ListNodesByUser(tx, name) if err != nil { return err } @@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.listPreAuthKeys(name) + keys, err := ListPreAuthKeys(tx, name) if err != nil { return err } for _, key := range keys { - err = hsdb.destroyPreAuthKey(key) + err = DestroyPreAuthKey(tx, key) if err != nil { return err } } - if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { + if result := tx.Unscoped().Delete(&user); result.Error != nil { return result.Error } return nil } +func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return RenameUser(tx, oldName, newName) + }) +} + // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. 
-func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameUser(tx *gorm.DB, oldName, newName string) error { var err error - oldUser, err := hsdb.getUser(oldName) + oldUser, err := GetUser(tx, oldName) if err != nil { return err } @@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.getUser(newName) + _, err = GetUser(tx, newName) if err == nil { return ErrUserExists } @@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { oldUser.Name = newName - if result := hsdb.db.Save(&oldUser); result.Error != nil { + if result := tx.Save(&oldUser); result.Error != nil { return result.Error } return nil } -// GetUser fetches a user by name. func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getUser(name) + return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUser(rx, name) + }) } -func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { +func GetUser(tx *gorm.DB, name string) (*types.User, error) { user := types.User{} - if result := hsdb.db.First(&user, "name = ?", name); errors.Is( + if result := tx.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { return &user, nil } -// ListUsers gets all the existing users. func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listUsers() + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { + return ListUsers(rx) + }) } -func (hsdb *HSDatabase) listUsers() ([]types.User, error) { +// ListUsers gets all the existing users. 
+func ListUsers(tx *gorm.DB) ([]types.User, error) { users := []types.User{} - if err := hsdb.db.Find(&users).Error; err != nil { + if err := tx.Find(&users).Error; err != nil { return nil, err } @@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) { } // ListNodesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByUser(name) -} - -func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) { +func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.getUser(name) + user, err := GetUser(tx, name) if err != nil { return nil, err } nodes := types.Nodes{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -// AssignNodeToUser assigns a Node to a user. func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, node, username) + }) +} +// AssignNodeToUser assigns a Node to a user. 
+func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.getUser(username) + user, err := GetUser(tx, username) if err != nil { return err } node.User = *user - if result := hsdb.db.Save(&node); result.Error != nil { + if result := tx.Save(&node); result.Error != nil { return result.Error } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ca3b49..b36e861 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { err = db.DestroyUser("test") c.Assert(err, check.IsNil) - result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) + result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) // destroying a user also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) @@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) err = db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) @@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) c.Assert(node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, newUser.Name) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index ad325c7..52a63e9 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -211,7 +211,7 @@ func DERPProbeHandler( // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. 
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns -// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL +// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL. func DERPBootstrapDNSHandler( derpMap *tailcfg.DERPMap, ) func(http.ResponseWriter, *http.Request) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index ffd3a57..c12ba73 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,11 +8,13 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) - if err != nil { - return nil, err - } + err := api.h.db.DB.Transaction(func(tx *gorm.DB) error { + preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) + if err != nil { + return err + } - err = api.h.db.ExpirePreAuthKey(preAuthKey) + return db.ExpirePreAuthKey(tx, preAuthKey) + }) if err != nil { return nil, err } @@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } - node, err := api.h.db.RegisterNodeFromAuthCallback( - api.h.registrationCache, - mkey, - request.GetUser(), - nil, - util.RegisterMethodCLI, - ) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + return db.RegisterNodeFromAuthCallback( + tx, + api.h.registrationCache, + mkey, + request.GetUser(), + nil, + util.RegisterMethodCLI, + api.h.cfg.IPPrefixes, + ) + 
}) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RegisterNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err - } - for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { - return &v1.SetTagsResponse{ - Node: nil, - }, status.Error(codes.InvalidArgument, err.Error()) + return nil, err } } - err = api.h.db.SetTags(node, request.GetTags()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.SetTags(tx, request.GetNodeId(), request.GetTags()) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return &v1.SetTagsResponse{ Node: nil, - }, status.Error(codes.Internal, err.Error()) + }, status.Error(codes.InvalidArgument, err.Error()) + } + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.SetTags", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). 
@@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode( err = api.h.db.DeleteNode( node, + api.h.nodeNotifier.ConnectedMap(), ) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) + api.h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return &v1.DeleteNodeResponse{}, nil } @@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + now := time.Now() + + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + db.NodeSetExpiry( + tx, + request.GetNodeId(), + now, + ) + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - now := time.Now() + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) + api.h.nodeNotifier.NotifyByMachineKey( + ctx, + selfUpdate, + node.MachineKey) + } - api.h.db.NodeSetExpiry( - node, - now, - ) + stateUpdate := types.StateUpdateExpire(node.ID, now) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } log.Trace(). Str("node", node.Hostname). 
@@ -306,17 +365,30 @@ func (api headscaleV1APIServer) RenameNode( ctx context.Context, request *v1.RenameNodeRequest, ) (*v1.RenameNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.RenameNode( + tx, + request.GetNodeId(), + request.GetNewName(), + ) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - err = api.h.db.RenameNode( - node, - request.GetNewName(), - ) - if err != nil { - return nil, err + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RenameNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). @@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + isConnected := api.h.nodeNotifier.ConnectedMap() if request.GetUser() != "" { - nodes, err := api.h.db.ListNodesByUser(request.GetUser()) + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { + return db.ListNodesByUser(rx, request.GetUser()) + }) if err != nil { return nil, err } @@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] response[index] = resp } @@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. 
- resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - &node, + node, ) resp.InvalidTags = invalidTags resp.ValidTags = validTags @@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes( ctx context.Context, request *v1.GetRoutesRequest, ) (*v1.GetRoutesResponse, error) { - routes, err := api.h.db.GetRoutes() + routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) { + return db.GetRoutes(rx) + }) if err != nil { return nil, err } @@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute( ctx context.Context, request *v1.EnableRouteRequest, ) (*v1.EnableRouteResponse, error) { - err := api.h.db.EnableRoute(request.GetRouteId()) + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnableRoute(tx, request.GetRouteId()) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown") + api.h.nodeNotifier.NotifyAll( + ctx, *update) + } + return &v1.EnableRouteResponse{}, nil } @@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute( ctx context.Context, request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { - err := api.h.db.DisableRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DisableRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") + api.h.nodeNotifier.NotifyAll(ctx, *update) + } + return &v1.DisableRouteResponse{}, nil } @@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - err := 
api.h.db.DeleteRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DeleteRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") + api.h.nodeNotifier.NotifyWithIgnore(ctx, *update) + } + return &v1.DeleteRouteResponse{}, nil } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 9998f12..df0f4d9 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -272,6 +272,7 @@ func (m *Mapper) LiteMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, pol *policy.ACLPolicy, + messages ...string, ) ([]byte, error) { resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) if err != nil { @@ -290,7 +291,7 @@ func (m *Mapper) LiteMapResponse( resp.PacketFilter = policy.ReduceFilterRules(node, rules) resp.SSHPolicy = sshPolicy - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) 
} func (m *Mapper) KeepAliveResponse( @@ -392,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse( } if patches, ok := m.patches[uint64(change.NodeID)]; ok { - patches := append(patches, p) - - m.patches[uint64(change.NodeID)] = patches + m.patches[uint64(change.NodeID)] = append(patches, p) } else { m.patches[uint64(change.NodeID)] = []patch{p} } @@ -470,6 +469,8 @@ func (m *Mapper) marshalMapResponse( switch { case resp.Peers != nil && len(resp.Peers) > 0: responseType = "full" + case isSelfUpdate(messages...): + responseType = "self" case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil: responseType = "lite" case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: @@ -668,3 +669,13 @@ func appendPeerChanges( return nil } + +func isSelfUpdate(messages ...string) bool { + for _, message := range messages { + if strings.Contains(message, types.SelfUpdateIdentifier) { + return true + } + } + + return false +} diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index e213a95..c10da4d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -72,7 +72,7 @@ func tailNode( } var derp string - if node.Hostinfo.NetInfo != nil { + if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil { derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) } else { derp = "127.3.3.40:0" // Zero means disconnected or unknown. 
diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 77e8b19..2384a40 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -1,6 +1,7 @@ package notifier import ( + "context" "fmt" "strings" "sync" @@ -12,26 +13,30 @@ import ( ) type Notifier struct { - l sync.RWMutex - nodes map[string]chan<- types.StateUpdate + l sync.RWMutex + nodes map[string]chan<- types.StateUpdate + connected map[key.MachinePublic]bool } func NewNotifier() *Notifier { - return &Notifier{} + return &Notifier{ + nodes: make(map[string]chan<- types.StateUpdate), + connected: make(map[key.MachinePublic]bool), + } } func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to add node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { - n.nodes = make(map[string]chan<- types.StateUpdate) - } - n.nodes[machineKey.String()] = c + n.connected[machineKey] = true log.Trace(). Str("machine_key", machineKey.ShortString()). @@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to remove node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { + if len(n.nodes) == 0 { return } delete(n.nodes, machineKey.String()) + n.connected[machineKey] = false log.Trace(). Str("machine_key", machineKey.ShortString()). 
@@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { n.l.RLock() defer n.l.RUnlock() - if _, ok := n.nodes[machineKey.String()]; ok { - return true - } - - return false + return n.connected[machineKey] } -func (n *Notifier) NotifyAll(update types.StateUpdate) { - n.NotifyWithIgnore(update) +// TODO(kradalby): This returns a pointer and can be dangerous. +func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool { + return n.connected } -func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { +func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { + n.NotifyWithIgnore(ctx, update) +} + +func (n *Notifier) NotifyWithIgnore( + ctx context.Context, + update types.StateUpdate, + ignore ...string, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() @@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) continue } - log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } -func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) { +func (n *Notifier) NotifyByMachineKey( + ctx context.Context, + update types.StateUpdate, + mKey key.MachinePublic, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). 
Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() if c, ok := n.nodes[mKey.String()]; ok { - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 568519f..a0fc931 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -20,6 +20,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" + "gorm.io/gorm" "tailscale.com/types/key" ) @@ -492,7 +493,7 @@ func (h *Headscale) validateNodeForOIDCCallback( Str("node", node.Hostname). 
Msg("node already registered, reauthenticating") - err := h.db.NodeSetExpiry(node, expiry) + err := h.db.NodeSetExpiry(node.ID, expiry) if err != nil { util.LogErr(err, "Failed to refresh node") http.Error( @@ -536,6 +537,12 @@ func (h *Headscale) validateNodeForOIDCCallback( util.LogErr(err, "Failed to write response") } + stateUpdate := types.StateUpdateExpire(node.ID, expiry) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return nil, true, nil } @@ -613,14 +620,22 @@ func (h *Headscale) registerNodeForOIDCCallback( machineKey *key.MachinePublic, expiry time.Time, ) error { - if _, err := h.db.RegisterNodeFromAuthCallback( - // TODO(kradalby): find a better way to use the cache across modules - h.registrationCache, - *machineKey, - user.Name, - &expiry, - util.RegisterMethodOIDC, - ); err != nil { + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if _, err := db.RegisterNodeFromAuthCallback( + // TODO(kradalby): find a better way to use the cache across modules + tx, + h.registrationCache, + *machineKey, + user.Name, + &expiry, + util.RegisterMethodOIDC, + h.cfg.IPPrefixes, + ); err != nil { + return err + } + + return nil + }); err != nil { util.LogErr(err, "could not register node") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 1dd664c..2ccc56b 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -905,32 +905,39 @@ func (pol *ACLPolicy) TagsOfNode( validTags := make([]string, 0) invalidTags := make([]string, 0) + // TODO(kradalby): Why is this sometimes nil? coming from tailNode? 
+ if node == nil { + return validTags, invalidTags + } + validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) - for _, tag := range node.Hostinfo.RequestTags { - owners, err := expandOwnersFromTag(pol, tag) - if errors.Is(err, ErrInvalidTag) { - invalidTagMap[tag] = true + if node.Hostinfo != nil { + for _, tag := range node.Hostinfo.RequestTags { + owners, err := expandOwnersFromTag(pol, tag) + if errors.Is(err, ErrInvalidTag) { + invalidTagMap[tag] = true - continue - } - var found bool - for _, owner := range owners { - if node.User.Name == owner { - found = true + continue + } + var found bool + for _, owner := range owners { + if node.User.Name == owner { + found = true + } + } + if found { + validTagMap[tag] = true + } else { + invalidTagMap[tag] = true } } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true + for tag := range invalidTagMap { + invalidTags = append(invalidTags, tag) + } + for tag := range validTagMap { + validTags = append(validTags, tag) } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) } return validTags, invalidTags diff --git a/hscontrol/poll.go b/hscontrol/poll.go index c867f26..f00152d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -4,12 +4,15 @@ import ( "context" "fmt" "net/http" + "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" xslices "golang.org/x/exp/slices" + "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -128,10 +131,14 @@ func (h *Headscale) handlePoll( if h.ACLPolicy != nil { // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) if err != nil { logErr(err, "Error running auto approved routes") } + + if update != nil { + 
sendUpdate = true + } } } @@ -146,7 +153,7 @@ func (h *Headscale) handlePoll( } if sendUpdate { - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -161,7 +168,9 @@ func (h *Headscale) handlePoll( Message: "called from handlePoll -> update -> new hostinfo", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } @@ -174,7 +183,9 @@ func (h *Headscale) handlePoll( ChangeNodes: types.Nodes{node}, } if selfUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname) h.nodeNotifier.NotifyByMachineKey( + ctx, selfUpdate, node.MachineKey) } @@ -183,7 +194,7 @@ func (h *Headscale) handlePoll( } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -195,7 +206,9 @@ func (h *Headscale) handlePoll( ChangePatches: []*tailcfg.PeerChange{&change}, } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } @@ -251,7 +264,7 @@ func (h *Headscale) handlePoll( } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -288,7 +301,10 @@ func (h *Headscale) handlePoll( // update ACLRules with peer informations (to update server tags if necessary) if h.ACLPolicy != nil { // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, 
node) + // This state update is ignored as it will be sent + // as part of the whole node + // TODO(kradalby): figure out if that is actually correct + _, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) if err != nil { logErr(err, "Error running auto approved routes") } @@ -324,11 +340,17 @@ func (h *Headscale) handlePoll( Message: "called from handlePoll -> new node added", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } + if len(node.Routes) > 0 { + go h.pollFailoverRoutes(logErr, "new node", node) + } + // Set up the client stream h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -346,15 +368,9 @@ func (h *Headscale) handlePoll( keepAliveTicker := time.NewTicker(keepAliveInterval) - ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname) - - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname)) defer cancel() - if len(node.Routes) > 0 { - go h.db.EnsureFailoverRouteIsAvailable(node) - } - for { logInfo("Waiting for update on stream channel") select { @@ -403,6 +419,7 @@ func (h *Headscale) handlePoll( return } + startMapResp := time.Now() switch update.Type { case types.StateFullUpdate: logInfo("Sending Full MapResponse") @@ -411,6 +428,7 @@ func (h *Headscale) handlePoll( case types.StatePeerChanged: logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) + isConnectedMap := h.nodeNotifier.ConnectedMap() for _, node := range update.ChangeNodes { // If a node is not reported to be online, it might be // because the value is outdated, check with the notifier. @@ -418,7 +436,7 @@ func (h *Headscale) handlePoll( // this might be because it has announced itself, but not // reached the stage to actually create the notifier channel. 
if node.IsOnline != nil && !*node.IsOnline { - isOnline := h.nodeNotifier.IsConnected(node.MachineKey) + isOnline := isConnectedMap[node.MachineKey] node.IsOnline = &isOnline } } @@ -434,7 +452,7 @@ func (h *Headscale) handlePoll( if len(update.ChangeNodes) == 1 { logInfo("Sending SelfUpdate MapResponse") node = update.ChangeNodes[0] - data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy) + data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier) } else { logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.") } @@ -449,8 +467,11 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response") + // Only send update if there is change if data != nil { + startWrite := time.Now() _, err = writer.Write(data) if err != nil { logErr(err, "Could not write the map response") @@ -468,6 +489,7 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node") log.Info(). Caller(). @@ -487,7 +509,7 @@ func (h *Headscale) handlePoll( go h.updateNodeOnlineStatus(false, node) // Failover the node's routes if any. - go h.db.FailoverNodeRoutesWithNotify(node) + go h.pollFailoverRoutes(logErr, "node closing connection", node) // The connection has been closed, so we can stop polling. 
return @@ -500,6 +522,22 @@ func (h *Headscale) handlePoll( } } +func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) { + update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node) + }) + if err != nil { + logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) + + return + } + + if update != nil && !update.Empty() && update.Valid() { + ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String()) + } +} + // updateNodeOnlineStatus records the last seen status of a node and notifies peers // about change in their online/offline status. // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. @@ -519,10 +557,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { }, } if statusUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String()) + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String()) } - err := h.db.UpdateLastSeen(node) + err := h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.UpdateLastSeen(tx, node.ID, *node.LastSeen) + }) if err != nil { log.Error().Err(err).Msg("Cannot update node LastSeen") diff --git a/hscontrol/poll_noise.go b/hscontrol/poll_noise.go index 675836a..53b1d47 100644 --- a/hscontrol/poll_noise.go +++ b/hscontrol/poll_noise.go @@ -13,7 +13,7 @@ import ( ) const ( - MinimumCapVersion tailcfg.CapabilityVersion = 56 + MinimumCapVersion tailcfg.CapabilityVersion = 58 ) // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go 
index e38d8e3..d45f9d4 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -1,15 +1,19 @@ package types import ( + "context" "database/sql/driver" "encoding/json" "errors" "fmt" "net/netip" + "time" "tailscale.com/tailcfg" ) +const SelfUpdateIdentifier = "self-update" + var ErrCannotParsePrefix = errors.New("cannot parse prefix") type IPPrefix netip.Prefix @@ -160,3 +164,37 @@ func (su *StateUpdate) Valid() bool { return true } + +// Empty reports if there are any updates in the StateUpdate. +func (su *StateUpdate) Empty() bool { + switch su.Type { + case StatePeerChanged: + return len(su.ChangeNodes) == 0 + case StatePeerChangedPatch: + return len(su.ChangePatches) == 0 + case StatePeerRemoved: + return len(su.Removed) == 0 + } + + return false +} + +func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(nodeID), + KeyExpiry: &expiry, + }, + }, + } +} + +func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { + ctx2, _ := context.WithTimeout( + context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin), + 3*time.Second, + ) + return ctx2 +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 7f6b40e..0b8324f 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,7 +2,6 @@ package types import ( "strconv" - "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -22,12 +21,13 @@ type User struct { func (n *User) TailscaleUser() *tailcfg.User { user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.UserID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", Logins: []tailcfg.LoginID{}, - Created: time.Time{}, + Created: n.CreatedAt, } return &user 
@@ -35,9 +35,10 @@ func (n *User) TailscaleUser() *tailcfg.User { func (n *User) TailscaleLogin() *tailcfg.Login { login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", } diff --git a/integration/cli_test.go b/integration/cli_test.go index d2d741e..e6190fb 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAll[4].GetGivenName(), "node-5") for idx := 0; idx < 3; idx++ { - _, err := headscale.Execute( + res, err := headscale.Execute( []string{ "headscale", "nodes", @@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) { }, ) assert.Nil(t, err) + + assert.Contains(t, res, "Node renamed") } var listAllAfterRename []v1.Node diff --git a/integration/general_test.go b/integration/general_test.go index 15c3a72..5c98cd2 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -26,6 +26,8 @@ func TestPingAllByIP(t *testing.T) { assertNoErr(t, err) defer scenario.Shutdown() + // TODO(kradalby): it does not look like the user thing works, only second + // get created? maybe only when many? 
spec := map[string]int{ "user1": len(MustTestVersions), "user2": len(MustTestVersions), @@ -321,7 +323,12 @@ func TestTaildrop(t *testing.T) { t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) } } - curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} + curlCommand := []string{ + "curl", + "--unix-socket", + "/var/run/tailscale/tailscaled.sock", + "http://local-tailscaled.sock/localapi/v0/file-targets", + } err = retry(10, 1*time.Second, func() error { result, _, err := client.Execute(curlCommand) if err != nil { @@ -338,13 +345,23 @@ func TestTaildrop(t *testing.T) { for _, ft := range fts { ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) } - return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) + return fmt.Errorf( + "client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", + client.Hostname(), + len(fts), + len(allClients)-1, + ftStr, + ) } return err }) if err != nil { - t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) + t.Errorf( + "failed to query localapi for filetarget on %s, err: %s", + client.Hostname(), + err, + ) } } @@ -434,72 +451,6 @@ func TestTaildrop(t *testing.T) { } } -func TestResolveMagicDNS(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "magicdns1": len(MustTestVersions), - "magicdns2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) - assertNoErrHeadscaleEnv(t, err) - - allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - // Poor mans cache - _, err = scenario.ListTailscaleClientsFQDNs() 
- assertNoErrListFQDN(t, err) - - _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) - - for _, client := range allClients { - for _, peer := range allClients { - // It is safe to ignore this error as we handled it when caching it - peerFQDN, _ := peer.FQDN() - - command := []string{ - "tailscale", - "ip", peerFQDN, - } - result, _, err := client.Execute(command) - if err != nil { - t.Fatalf( - "failed to execute resolve/ip command %s from %s: %s", - peerFQDN, - client.Hostname(), - err, - ) - } - - ips, err := peer.IPs() - if err != nil { - t.Fatalf( - "failed to get ips for %s: %s", - peer.Hostname(), - err, - ) - } - - for _, ip := range ips { - if !strings.Contains(result, ip.String()) { - t.Fatalf("ip %s is not found in \n%s\n", ip.String(), result) - } - } - } - } -} - func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel() @@ -545,7 +496,7 @@ func TestExpireNode(t *testing.T) { // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ - "headscale", "nodes", "expire", "--identifier", "0", "--output", "json", + "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) assertNoErr(t, err) @@ -576,16 +527,38 @@ func TestExpireNode(t *testing.T) { assertNotNil(t, peerStatus.Expired) assert.NotNil(t, peerStatus.KeyExpiry) - t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + t.Logf( + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) if peerStatus.KeyExpiry != nil { - assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + assert.Truef( + t, + peerStatus.KeyExpiry.Before(now), + "node %q should have a key expire before %s, was %s", + 
peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) } - assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + assert.Truef( + t, + peerStatus.Expired, + "node %q should be expired, expired is %v", + peerStatus.HostName, + peerStatus.Expired, + ) _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) if !strings.Contains(stderr, "node key has expired") { - t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname()) + t.Errorf( + "expected to be unable to ping expired host %q from %q", + node.GetName(), + client.Hostname(), + ) } } else { t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) @@ -597,7 +570,7 @@ func TestExpireNode(t *testing.T) { // NeedsLogin means that the node has understood that it is no longer // valid. - assert.Equal(t, "NeedsLogin", status.BackendState) + assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName) } } } @@ -690,7 +663,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { assert.Truef( t, lastSeen.After(lastSeenThreshold), - "lastSeen (%v) was not %s after the threshold (%v)", + "node (%s) lastSeen (%v) was not %s after the threshold (%v)", + node.GetName(), lastSeen, keepAliveInterval, lastSeenThreshold, diff --git a/integration/route_test.go b/integration/route_test.go index 741ba24..75296fd 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -88,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, routes, 3) for _, route := range routes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } // Verify that 
no routes has been sent to the client, @@ -135,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, enablingRoutes, 3) for _, route := range enablingRoutes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } time.Sleep(5 * time.Second) @@ -191,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) { }) assertNoErr(t, err) + time.Sleep(5 * time.Second) + var disablingRoutes []*v1.Route err = executeAndUnmarshal( headscale, @@ -209,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) { assert.Equal(t, true, route.GetAdvertised()) if route.GetId() == routeToBeDisabled.GetId() { - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } else { - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } } - time.Sleep(5 * time.Second) - // Verify that the clients can see the new routes for _, client := range allClients { status, err := client.Status() @@ -294,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // advertise HA route on node 1 and 2 // ID 1 will be primary // ID 2 will be secondary - for _, client := range allClients { + for _, client := range allClients[:2] { status, err := client.Status() assertNoErr(t, err) @@ -306,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) { } _, _, err = client.Execute(command) assertNoErrf(t, "failed to advertise route: %s", err) + } else { + t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID) } } @@ -328,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, 
routes, 2) + t.Logf("initial routes %#v", routes) + for _, route := range routes { assert.Equal(t, true, route.GetAdvertised()) assert.Equal(t, false, route.GetEnabled()) @@ -644,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routesAfterDisabling1, 2) + t.Logf("routes after disabling1 %#v", routesAfterDisabling1) + // Node 1 is not primary assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) diff --git a/integration/scenario.go b/integration/scenario.go index c11af72..16ec6f4 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -56,8 +56,8 @@ var ( "1.44": true, // CapVer: 63 "1.42": true, // CapVer: 61 "1.40": true, // CapVer: 61 - "1.38": true, // CapVer: 58 - "1.36": true, // Oldest supported version, CapVer: 56 + "1.38": true, // Oldest supported version, CapVer: 58 + "1.36": false, // CapVer: 56 "1.34": false, // CapVer: 51 "1.32": false, // CapVer: 46 "1.30": false, From 00e7550e760b2d3d759471ff55d2b6e2dc81ad2b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 9 Feb 2024 07:26:41 +0100 Subject: [PATCH 09/13] Add assert func for verifying status, netmap and netcheck (#1723) --- ...egration-v2-TestPingAllByIPPublicDERP.yaml | 67 ++++++++ hscontrol/poll.go | 10 +- hscontrol/types/node_test.go | 108 +++++++++++++ hscontrol/util/test.go | 6 +- integration/auth_oidc_test.go | 4 + integration/auth_web_flow_test.go | 4 + integration/embedded_derp_test.go | 29 ++-- integration/general_test.go | 140 ++++++++++++++++- integration/tailscale.go | 2 + integration/tsic/tsic.go | 34 ++++ integration/utils.go | 148 +++++++++++++++++- 11 files changed, 534 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml diff --git a/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml b/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml new file 
mode 100644 index 0000000..18fd341 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestPingAllByIPPublicDERP + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestPingAllByIPPublicDERP: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestPingAllByIPPublicDERP + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... 
\ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestPingAllByIPPublicDERP$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/hscontrol/poll.go b/hscontrol/poll.go index f00152d..03f52ed 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -201,9 +201,15 @@ func (h *Headscale) handlePoll( return } + // TODO(kradalby): Figure out why patch changes does + // not show up in output from `tailscale debug netmap`. + // stateUpdate := types.StateUpdate{ + // Type: types.StatePeerChangedPatch, + // ChangePatches: []*tailcfg.PeerChange{&change}, + // } stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{&change}, + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, } if stateUpdate.Valid() { ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname) diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 7e6c984..712a839 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -6,6 +6,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) { }) } } + +func TestApplyPeerChange(t *testing.T) { + tests := []struct { + name string + nodeBefore Node + change *tailcfg.PeerChange + want Node + }{ + { + name: "hostinfo-and-netinfo-not-exists", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + DERPRegion: 1, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 1, + }, + }, + }, + }, + { + name: 
"hostinfo-netinfo-not-exists", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 3, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 3, + }, + }, + }, + }, + { + name: "hostinfo-netinfo-exists-derp-set", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 999, + }, + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 2, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 2, + }, + }, + }, + }, + { + name: "endpoints-not-set", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + { + name: "endpoints-set", + nodeBefore: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("6.6.6.6:66"), + }, + }, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.nodeBefore.ApplyPeerChange(tt.change) + + if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go index 6d46542..0a23acb 100644 --- a/hscontrol/util/test.go +++ b/hscontrol/util/test.go @@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool { return x.Compare(y) == 0 }) +var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool { + return x == y +}) + var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { return x.String() == 
y.String() }) @@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { }) var Comparers []cmp.Option = []cmp.Option{ - IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer, + IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 7a0ed9c..36e74a8 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -83,6 +83,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -140,6 +142,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 90ce571..aa589fa 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 3a40749..15ab7ad 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) { defer 
scenario.Shutdown() spec := map[string]int{ - "user1": len(MustTestVersions), + "user1": 10, + // "user1": len(MustTestVersions), } - headscaleConfig := map[string]string{} - headscaleConfig["HEADSCALE_DERP_URLS"] = "" - headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" - headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" - headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" - // Envknob for enabling DERP debug logs - headscaleConfig["DERP_DEBUG_LOGS"] = "true" - headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true" + headscaleConfig := map[string]string{ + "HEADSCALE_DERP_URLS": "", + "HEADSCALE_DERP_SERVER_ENABLED": "true", + "HEADSCALE_DERP_SERVER_REGION_ID": "999", + "HEADSCALE_DERP_SERVER_REGION_CODE": "headscale", + "HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP", + "HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478", + "HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key", + + // Envknob for enabling DERP debug logs + "DERP_DEBUG_LOGS": "true", + "DERP_PROBER_DEBUG_LOGS": "true", + } err = scenario.CreateHeadscaleEnv( spec, @@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allHostnames, err := scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) diff --git a/integration/general_test.go b/integration/general_test.go index 5c98cd2..9aae26f 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -33,7 +33,27 @@ func TestPingAllByIP(t *testing.T) { "user2": len(MustTestVersions), } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip")) + headscaleConfig := map[string]string{ + 
"HEADSCALE_DERP_URLS": "", + "HEADSCALE_DERP_SERVER_ENABLED": "true", + "HEADSCALE_DERP_SERVER_REGION_ID": "999", + "HEADSCALE_DERP_SERVER_REGION_CODE": "headscale", + "HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP", + "HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478", + "HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key", + + // Envknob for enabling DERP debug logs + "DERP_DEBUG_LOGS": "true", + "DERP_PROBER_DEBUG_LOGS": "true", + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithTestName("pingallbyip"), + hsic.WithConfigEnv(headscaleConfig), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + ) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -45,6 +65,46 @@ func TestPingAllByIP(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) +} + +func TestPingAllByIPPublicDERP(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario() + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithTestName("pingallbyippubderp"), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -75,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { err = 
scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + clientIPs := make(map[TailscaleClient][]netip.Addr) for _, client := range allClients { ips, err := client.IPs() @@ -114,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allClients, err = scenario.ListTailscaleClients() assertNoErrListClients(t, err) @@ -265,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allHostnames, err := scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) @@ -451,6 +517,74 @@ func TestTaildrop(t *testing.T) { } } +func TestResolveMagicDNS(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario() + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "magicdns1": len(MustTestVersions), + "magicdns2": len(MustTestVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + assertClientsState(t, allClients) + + // Poor mans cache + _, err = scenario.ListTailscaleClientsFQDNs() + assertNoErrListFQDN(t, err) + + _, err = scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + for _, client := range allClients { + for _, peer := range allClients { + // It is safe to ignore this error as we handled it when caching it + peerFQDN, _ := peer.FQDN() + + command := []string{ + "tailscale", + "ip", peerFQDN, + } + result, _, err := client.Execute(command) + if err != nil { + t.Fatalf( + "failed to execute resolve/ip command %s from %s: %s", + peerFQDN, + client.Hostname(), + err, + ) + } + + ips, err := peer.IPs() + if err != nil 
{ + t.Fatalf( + "failed to get ips for %s: %s", + peer.Hostname(), + err, + ) + } + + for _, ip := range ips { + if !strings.Contains(result, ip.String()) { + t.Fatalf("ip %s is not found in \n%s\n", ip.String(), result) + } + } + } + } +} + func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel() @@ -475,6 +609,8 @@ func TestExpireNode(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -599,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) diff --git a/integration/tailscale.go b/integration/tailscale.go index 7187a81..9d6796b 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/netcheck" "tailscale.com/types/netmap" ) @@ -28,6 +29,7 @@ type TailscaleClient interface { FQDN() (string, error) Status() (*ipnstate.Status, error) Netmap() (*netmap.NetworkMap, error) + Netcheck() (*netcheck.Report, error) WaitForNeedsLogin() error WaitForRunning() error WaitForPeers(expected int) error diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index c30118d..854d5a7 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -17,6 +17,7 @@ import ( "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/netcheck" "tailscale.com/types/netmap" ) @@ -544,6 +545,29 @@ func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { return &nm, err } +// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale 
instance. +func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) { + command := []string{ + "tailscale", + "netcheck", + "--format=json", + } + + result, stderr, err := t.Execute(command) + if err != nil { + fmt.Printf("stderr: %s\n", stderr) + return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err) + } + + var nm netcheck.Report + err = json.Unmarshal([]byte(result), &nm) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err) + } + + return &nm, err +} + // FQDN returns the FQDN as a string of the Tailscale instance. func (t *TailscaleInContainer) FQDN() (string, error) { if t.fqdn != "" { @@ -648,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error { len(peers), ) } else { + // Verify that the peers of a given node is Online + // has a hostname and a DERP relay. for _, peerKey := range peers { peer := status.Peer[peerKey] if !peer.Online { return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName) } + + if peer.HostName == "" { + return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName) + } + + if peer.Relay == "" { + return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName) + } } } diff --git a/integration/utils.go b/integration/utils.go index e17e18a..ae4441b 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -7,6 +7,8 @@ import ( "time" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "tailscale.com/util/cmpver" ) const ( @@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts for _, addr := range addrs { err := client.Ping(addr, opts...) 
if err != nil { - t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) } else { success++ } @@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) return success } +// assertClientsState validates the status and netmap of a list of +// clients for the general case of all to all connectivity. +func assertClientsState(t *testing.T, clients []TailscaleClient) { + t.Helper() + + for _, client := range clients { + assertValidStatus(t, client) + assertValidNetmap(t, client) + assertValidNetcheck(t, client) + } +} + +// assertValidNetmap asserts that the netmap of a client has all +// the minimum required fields set to a known working config for +// the general case. Fields are checked on self, then all peers. +// This test is not suitable for ACL/partial connection tests. +// This test can only be run on clients from 1.56.1. It will +// automatically pass all clients below that and is safe to call +// for all versions. 
+func assertValidNetmap(t *testing.T, client TailscaleClient) { + t.Helper() + + if cmpver.Compare("1.56.1", client.Version()) <= 0 || + !strings.Contains(client.Hostname(), "unstable") || + !strings.Contains(client.Hostname(), "head") { + return + } + + netmap, err := client.Netmap() + if err != nil { + t.Fatalf("getting netmap for %q: %s", client.Hostname(), err) + } + + assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { + assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) + } + + assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) + assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) + + assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname()) + + assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) + assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) + assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) + + for _, peer := range netmap.Peers { + assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP()) + + assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { + assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) + + // Netinfo is not always set + assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have 
NetInfo", peer.ComputedName(), client.Hostname()) + if ni := hi.NetInfo(); ni.Valid() { + assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) + } + + } + + assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) + + assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) + + assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) + } +} + +// assertValidStatus asserts that the status of a client has all +// the minimum required fields set to a known working config for +// the general case. Fields are checked on self, then all peers. +// This test is not suitable for ACL/partial connection tests. 
+func assertValidStatus(t *testing.T, client TailscaleClient) {
+	t.Helper()
+	status, err := client.Status()
+	if err != nil {
+		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
+	}
+
+	assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
+
+	assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
+
+	// This seems to not appear until version 1.56
+	if status.Self.AllowedIPs != nil {
+		assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
+	}
+
+	assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
+
+	assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
+
+	assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
+
+	// This isn't really relevant for Self as it won't be in its own socket/wireguard. 
+	// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
+	// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
+
+	for _, peer := range status.Peer {
+		assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
+
+		assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
+
+		// This seems to not appear until version 1.56
+		if peer.AllowedIPs != nil {
+			assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
+		}
+
+		// Addrs does not seem to appear in the status from peers.
+		// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
+		assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
+
+		// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
+		// there might be some interesting stuff to test here in the future. 
+ // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) + } +} + +func assertValidNetcheck(t *testing.T, client TailscaleClient) { + t.Helper() + report, err := client.Netcheck() + if err != nil { + t.Fatalf("getting status for %q: %s", client.Hostname(), err) + } + + assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) +} + func isSelfClient(client TailscaleClient, addr string) bool { if addr == client.Hostname() { return true @@ -152,7 +296,7 @@ func isCI() bool { } func dockertestMaxWait() time.Duration { - wait := 60 * time.Second //nolint + wait := 120 * time.Second //nolint if isCI() { wait = 300 * time.Second //nolint From 94b30abf56ae09d82a1541bbc3d19557914f9b27 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 9 Feb 2024 07:27:00 +0100 Subject: [PATCH 10/13] Restructure database config (#1700) --- CHANGELOG.md | 22 +++++---- cmd/headscale/headscale_test.go | 6 ++- config-example.yaml | 31 ++++++------ hscontrol/app.go | 47 +++--------------- hscontrol/db/db.go | 56 ++++++++++++++------- hscontrol/db/routes_test.go | 11 +++-- hscontrol/db/suite_test.go | 12 +++-- hscontrol/suite_test.go | 8 ++- hscontrol/types/common.go | 10 +++- hscontrol/types/config.go | 87 ++++++++++++++++++++++++++------- integration/hsic/config.go | 4 +- 11 files changed, 180 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7908ea..3adb23b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,16 +34,18 @@ after improving the test harness as part of adopting [#1460](https://github.com/ ### Changes -Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) -Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) -SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) -State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) -Use error group 
handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) -Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) -Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) -Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) -Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) -Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) +- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) +- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) +- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) +- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) +- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) +- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) +- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) +- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +- Added the possibility to manually create a DERP-map entry 
which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) +- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700) + - Old structure is now considered deprecated and will be removed in the future. ## 0.22.3 (2023-05-12) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 897e253..d73d30b 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") + c.Assert(viper.GetString("database.type"), check.Equals, "sqlite") + c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") @@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, 
"/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") diff --git a/config-example.yaml b/config-example.yaml index 96a654a..8e4373f 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -138,24 +138,25 @@ ephemeral_node_inactivity_timeout: 30m # In case of doubts, do not touch the default 10s. node_update_check_interval: 10s -# SQLite config -db_type: sqlite3 +database: + type: sqlite -# For production: -db_path: /var/lib/headscale/db.sqlite + # SQLite config + sqlite: + path: /var/lib/headscale/db.sqlite -# # Postgres config -# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. -# db_type: postgres -# db_host: localhost -# db_port: 5432 -# db_name: headscale -# db_user: foo -# db_pass: bar + # # Postgres config + # postgres: + # # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. + # host: localhost + # port: 5432 + # name: headscale + # user: foo + # pass: bar -# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need -# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. -# db_ssl: false + # # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need + # # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. 
+ # ssl: false ### TLS configuration # diff --git a/hscontrol/app.go b/hscontrol/app.go index 91d5326..78b72bf 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -12,7 +12,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -118,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -156,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, @@ -165,9 +131,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, + cfg.Database, + app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -755,14 +720,16 @@ func (h *Headscale) Serve() error { var tailsqlContext context.Context if tailsqlEnabled { - if h.cfg.DBtype != db.Sqlite { - log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite) + if h.cfg.Database.Type != types.DatabaseSqlite { + log.Fatal(). + Str("type", h.cfg.Database.Type). 
+ Msgf("tailsql only support %q", types.DatabaseSqlite) } if tailsqlTSKey == "" { log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") } tailsqlContext = context.Background() - go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath) + go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) } // Handle common process-killing signals so we can gracefully shut down: diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index df7b0a4..fe77dda 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,11 +6,13 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -19,11 +21,6 @@ import ( "gorm.io/gorm/logger" ) -const ( - Postgres = "postgres" - Sqlite = "sqlite3" -) - var errDatabaseNotSupported = errors.New("database type not supported") // KV is a key-value store in a psql table. For future use... @@ -43,12 +40,12 @@ type HSDatabase struct { // TODO(kradalby): assemble this struct from toptions or something typed // rather than arguments. 
func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, + cfg types.DatabaseConfig, + notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) + dbConn, err := openDB(cfg) if err != nil { return nil, err } @@ -62,7 +59,7 @@ func NewHeadscaleDatabase( { ID: "202312101416", Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { + if cfg.Type == types.DatabasePostgres { tx.Exec(`create extension if not exists "uuid-ossp";`) } @@ -321,20 +318,20 @@ func NewHeadscaleDatabase( return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") +func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { + // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface - if debug { + if cfg.Debug { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { - case Sqlite: + switch cfg.Type { + case types.DatabaseSqlite: db, err := gorm.Open( - sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), + sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, @@ -353,8 +350,31 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return db, err - case Postgres: - return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ + case types.DatabasePostgres: + dbString := fmt.Sprintf( + "host=%s dbname=%s user=%s", + cfg.Postgres.Host, + cfg.Postgres.Name, + cfg.Postgres.User, + ) + + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if !sslEnabled { + dbString += " sslmode=disable" + } + } else { + dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl) + } + + if cfg.Postgres.Port != 0 { + dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port) 
+ } + + if cfg.Postgres.Pass != "" { + dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass) + } + + return gorm.Open(postgres.Open(dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, }) @@ -362,7 +382,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + cfg.Type, errDatabaseNotSupported, ) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 3b544aa..5d6281e 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -654,9 +655,13 @@ func TestFailoverRoute(t *testing.T) { assert.NoError(t, err) db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index d4b11b1..e176e4b 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" ) @@ -44,9 +46,13 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git 
a/hscontrol/suite_test.go b/hscontrol/suite_test.go index 82bdc79..3f0cc42 100644 --- a/hscontrol/suite_test.go +++ b/hscontrol/suite_test.go @@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) { } cfg := types.Config{ NoisePrivateKeyPath: tmpDir + "/noise_private.key", - DBtype: "sqlite3", - DBpath: tmpDir + "/headscale_test.db", + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d45f9d4..ceeceea 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -12,7 +12,11 @@ import ( "tailscale.com/tailcfg" ) -const SelfUpdateIdentifier = "self-update" +const ( + SelfUpdateIdentifier = "self-update" + DatabasePostgres = "postgres" + DatabaseSqlite = "sqlite3" +) var ErrCannotParsePrefix = errors.New("cannot parse prefix") @@ -154,7 +158,9 @@ func (su *StateUpdate) Valid() bool { } case StateSelfUpdate: if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 { - panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node") + panic( + "Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node", + ) } case StateDERPUpdated: if su.DERPMap == nil { diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index d9d5830..d83b21f 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -46,16 +46,9 @@ type Config struct { Log LogConfig DisableUpdateCheck bool - DERP DERPConfig + Database DatabaseConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DERP DERPConfig TLS TLSConfig @@ -77,6 +70,28 @@ type Config struct { ACL ACLConfig } +type SqliteConfig struct { + Path string +} + +type PostgresConfig struct { + Host string + Port int + Name string + User string + Pass string + Ssl 
string +} + +type DatabaseConfig struct { + // Type sets the database type, either "sqlite3" or "postgres" + Type string + Debug bool + + Sqlite SqliteConfig + Postgres PostgresConfig +} + type TLSConfig struct { CertPath string KeyPath string @@ -161,6 +176,19 @@ func LoadConfig(path string, isFile bool) error { viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() + viper.RegisterAlias("db_type", "database.type") + + // SQLite aliases + viper.RegisterAlias("db_path", "database.sqlite.path") + + // Postgres aliases + viper.RegisterAlias("db_host", "database.postgres.host") + viper.RegisterAlias("db_port", "database.postgres.port") + viper.RegisterAlias("db_name", "database.postgres.name") + viper.RegisterAlias("db_user", "database.postgres.user") + viper.RegisterAlias("db_pass", "database.postgres.pass") + viper.RegisterAlias("db_ssl", "database.postgres.ssl") + viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType) @@ -184,6 +212,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("cli.insecure", false) viper.SetDefault("db_ssl", false) + viper.SetDefault("database.postgres.ssl", false) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.strip_email_domain", true) @@ -389,6 +418,37 @@ func GetLogConfig() LogConfig { } } +func GetDatabaseConfig() DatabaseConfig { + debug := viper.GetBool("database.debug") + + type_ := viper.GetString("database.type") + + switch type_ { + case DatabaseSqlite, DatabasePostgres: + break + case "sqlite": + type_ = "sqlite3" + default: + log.Fatal().Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) + } + + return DatabaseConfig{ + Type: type_, + Debug: debug, + Sqlite: SqliteConfig{ + Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")), + }, + Postgres: PostgresConfig{ + Host: 
viper.GetString("database.postgres.host"), + Port: viper.GetInt("database.postgres.port"), + Name: viper.GetString("database.postgres.name"), + User: viper.GetString("database.postgres.user"), + Pass: viper.GetString("database.postgres.pass"), + Ssl: viper.GetString("database.postgres.ssl"), + }, + } +} + func GetDNSConfig() (*tailcfg.DNSConfig, string) { if viper.IsSet("dns_config") { dnsConfig := &tailcfg.DNSConfig{} @@ -617,14 +677,7 @@ func GetHeadscaleConfig() (*Config, error) { "node_update_check_interval", ), - DBtype: viper.GetString("db_type"), - DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - DBssl: viper.GetString("db_ssl"), + Database: GetDatabaseConfig(), TLS: GetTLSConfig(), diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 00c1770..819b108 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string { return map[string]string{ "HEADSCALE_LOG_LEVEL": "trace", "HEADSCALE_ACL_POLICY_PATH": "", - "HEADSCALE_DB_TYPE": "sqlite3", - "HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_DATABASE_TYPE": "sqlite", + "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s", "HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10", From 91bb85e7d2aff8aec2bc42050d7da02be0353d75 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 9 Feb 2024 07:27:13 +0100 Subject: [PATCH 11/13] Update bug_report.md (#1672) --- .github/ISSUE_TEMPLATE/bug_report.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 02e4742..8563e7a 100644 --- 
a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -50,3 +50,16 @@ instead of filing a bug report. ## To Reproduce + +## Logs and attachments + + From 9047c09871dec7f84a00311e3ee14365f870810d Mon Sep 17 00:00:00 2001 From: Pallab Pain Date: Fri, 9 Feb 2024 22:04:28 +0530 Subject: [PATCH 12/13] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20pqsql=20configs?= =?UTF-8?q?=20for=20open=20and=20idle=20connections=20(#1583)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When Postgres is used as the backing database for headscale, it does not set a limit on maximum open and idle connections which leads to hundreds of open connections to the Postgres server. This commit introduces the configuration variables to set those values and also sets default while opening a new postgres connection. --- CHANGELOG.md | 1 + config-example.yaml | 3 + hscontrol/db/db.go | 478 ++++++++++++++++++++------------------ hscontrol/types/config.go | 47 ++-- 4 files changed, 287 insertions(+), 242 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3adb23b..9c1d04b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) - Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700) - Old structure is now considered deprecated and will be removed in the future. + - Adds additional configuration for PostgreSQL for setting max open, idle connection and idle connection lifetime. 
## 0.22.3 (2023-05-12) diff --git a/config-example.yaml b/config-example.yaml index 8e4373f..d41771f 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -153,6 +153,9 @@ database: # name: headscale # user: foo # pass: bar + # max_open_conns: 10 + # max_idle_conns: 10 + # conn_max_idle_time_secs: 3600 # # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need # # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index fe77dda..4ded07f 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -12,13 +12,14 @@ import ( "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" + + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) var errDatabaseNotSupported = errors.New("database type not supported") @@ -50,259 +51,273 @@ func NewHeadscaleDatabase( return nil, err } - migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{ - // New migrations should be added as transactions at the end of this list. - // The initial commit here is quite messy, completely out of order and - // has no versioning and is the tech debt of not having versioned migrations - // prior to this point. This first migration is all DB changes to bring a DB - // up to 0.23.0. 
- { - ID: "202312101416", - Migrate: func(tx *gorm.DB) error { - if cfg.Type == types.DatabasePostgres { - tx.Exec(`create extension if not exists "uuid-ossp";`) - } - - _ = tx.Migrator().RenameTable("namespaces", "users") - - // the big rename from Machine to Node - _ = tx.Migrator().RenameTable("machines", "nodes") - _ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id") - - err = tx.AutoMigrate(types.User{}) - if err != nil { - return err - } - - _ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id") - _ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - - _ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses") - _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") - - // GivenName is used as the primary source of DNS names, make sure - // the field is populated and normalized if it was not when the - // node was registered. - _ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name") - - // If the Node table has a column for registered, - // find all occourences of "false" and drop them. Then - // remove the column. - if tx.Migrator().HasColumn(&types.Node{}, "registered") { - log.Info(). - Msg(`Database has legacy "registered" column in node, removing...`) - - nodes := types.Nodes{} - if err := tx.Not("registered").Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") + migrations := gormigrate.New( + dbConn, + gormigrate.DefaultOptions, + []*gormigrate.Migration{ + // New migrations should be added as transactions at the end of this list. + // The initial commit here is quite messy, completely out of order and + // has no versioning and is the tech debt of not having versioned migrations + // prior to this point. This first migration is all DB changes to bring a DB + // up to 0.23.0. 
+ { + ID: "202312101416", + Migrate: func(tx *gorm.DB) error { + if cfg.Type == types.DatabasePostgres { + tx.Exec(`create extension if not exists "uuid-ossp";`) } - for _, node := range nodes { - log.Info(). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Deleting unregistered node") - if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Error deleting unregistered node") - } - } + _ = tx.Migrator().RenameTable("namespaces", "users") - err := tx.Migrator().DropColumn(&types.Node{}, "registered") - if err != nil { - log.Error().Err(err).Msg("Error dropping registered column") - } - } + // the big rename from Machine to Node + _ = tx.Migrator().RenameTable("machines", "nodes") + _ = tx.Migrator(). + RenameColumn(&types.Route{}, "machine_id", "node_id") - err = tx.AutoMigrate(&types.Route{}) - if err != nil { - return err - } - - err = tx.AutoMigrate(&types.Node{}) - if err != nil { - return err - } - - // Ensure all keys have correct prefixes - // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 - type result struct { - ID uint64 - MachineKey string - NodeKey string - DiscoKey string - } - var results []result - err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes"). - Find(&results). 
- Error - if err != nil { - return err - } - - for _, node := range results { - mKey := node.MachineKey - if !strings.HasPrefix(node.MachineKey, "mkey:") { - mKey = "mkey:" + node.MachineKey - } - nKey := node.NodeKey - if !strings.HasPrefix(node.NodeKey, "nodekey:") { - nKey = "nodekey:" + node.NodeKey - } - - dKey := node.DiscoKey - if !strings.HasPrefix(node.DiscoKey, "discokey:") { - dKey = "discokey:" + node.DiscoKey - } - - err := tx.Exec( - "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", - sql.Named("mKey", mKey), - sql.Named("nKey", nKey), - sql.Named("dKey", dKey), - sql.Named("id", node.ID), - ).Error + err = tx.AutoMigrate(types.User{}) if err != nil { return err } - } - if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { - log.Info(). - Msgf("Database has legacy enabled_routes column in node, migrating...") + _ = tx.Migrator(). + RenameColumn(&types.Node{}, "namespace_id", "user_id") + _ = tx.Migrator(). + RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - type NodeAux struct { - ID uint64 - EnabledRoutes types.IPPrefixes - } + _ = tx.Migrator(). + RenameColumn(&types.Node{}, "ip_address", "ip_addresses") + _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") - nodesAux := []NodeAux{} - err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error - if err != nil { - log.Fatal().Err(err).Msg("Error accessing db") - } - for _, node := range nodesAux { - for _, prefix := range node.EnabledRoutes { - if err != nil { + // GivenName is used as the primary source of DNS names, make sure + // the field is populated and normalized if it was not when the + // node was registered. + _ = tx.Migrator(). + RenameColumn(&types.Node{}, "nickname", "given_name") + + // If the Node table has a column for registered, + // find all occourences of "false" and drop them. Then + // remove the column. + if tx.Migrator().HasColumn(&types.Node{}, "registered") { + log.Info(). 
+ Msg(`Database has legacy "registered" column in node, removing...`) + + nodes := types.Nodes{} + if err := tx.Not("registered").Find(&nodes).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for _, node := range nodes { + log.Info(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Msg("Deleting unregistered node") + if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil { log.Error(). Err(err). - Str("enabled_route", prefix.String()). - Msg("Error parsing enabled_route") - - continue + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Msg("Error deleting unregistered node") } + } - err = tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). - First(&types.Route{}). - Error - if err == nil { - log.Info(). - Str("enabled_route", prefix.String()). - Msg("Route already migrated to new table, skipping") + err := tx.Migrator().DropColumn(&types.Node{}, "registered") + if err != nil { + log.Error().Err(err).Msg("Error dropping registered column") + } + } - continue + err = tx.AutoMigrate(&types.Route{}) + if err != nil { + return err + } + + err = tx.AutoMigrate(&types.Node{}) + if err != nil { + return err + } + + // Ensure all keys have correct prefixes + // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 + type result struct { + ID uint64 + MachineKey string + NodeKey string + DiscoKey string + } + var results []result + err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes"). + Find(&results). 
+ Error + if err != nil { + return err + } + + for _, node := range results { + mKey := node.MachineKey + if !strings.HasPrefix(node.MachineKey, "mkey:") { + mKey = "mkey:" + node.MachineKey + } + nKey := node.NodeKey + if !strings.HasPrefix(node.NodeKey, "nodekey:") { + nKey = "nodekey:" + node.NodeKey + } + + dKey := node.DiscoKey + if !strings.HasPrefix(node.DiscoKey, "discokey:") { + dKey = "discokey:" + node.DiscoKey + } + + err := tx.Exec( + "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", + sql.Named("mKey", mKey), + sql.Named("nKey", nKey), + sql.Named("dKey", dKey), + sql.Named("id", node.ID), + ).Error + if err != nil { + return err + } + } + + if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { + log.Info(). + Msgf("Database has legacy enabled_routes column in node, migrating...") + + type NodeAux struct { + ID uint64 + EnabledRoutes types.IPPrefixes + } + + nodesAux := []NodeAux{} + err := tx.Table("nodes"). + Select("id, enabled_routes"). + Scan(&nodesAux). + Error + if err != nil { + log.Fatal().Err(err).Msg("Error accessing db") + } + for _, node := range nodesAux { + for _, prefix := range node.EnabledRoutes { + if err != nil { + log.Error(). + Err(err). + Str("enabled_route", prefix.String()). + Msg("Error parsing enabled_route") + + continue + } + + err = tx.Preload("Node"). + Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). + First(&types.Route{}). + Error + if err == nil { + log.Info(). + Str("enabled_route", prefix.String()). + Msg("Route already migrated to new table, skipping") + + continue + } + + route := types.Route{ + NodeID: node.ID, + Advertised: true, + Enabled: true, + Prefix: types.IPPrefix(prefix), + } + if err := tx.Create(&route).Error; err != nil { + log.Error().Err(err).Msg("Error creating route") + } else { + log.Info(). + Uint64("node_id", route.NodeID). + Str("prefix", prefix.String()). 
+ Msg("Route migrated") + } } + } - route := types.Route{ - NodeID: node.ID, - Advertised: true, - Enabled: true, - Prefix: types.IPPrefix(prefix), - } - if err := tx.Create(&route).Error; err != nil { - log.Error().Err(err).Msg("Error creating route") - } else { - log.Info(). - Uint64("node_id", route.NodeID). - Str("prefix", prefix.String()). - Msg("Route migrated") + err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") + if err != nil { + log.Error(). + Err(err). + Msg("Error dropping enabled_routes column") + } + } + + if tx.Migrator().HasColumn(&types.Node{}, "given_name") { + nodes := types.Nodes{} + if err := tx.Find(&nodes).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for item, node := range nodes { + if node.GivenName == "" { + normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( + node.Hostname, + ) + if err != nil { + log.Error(). + Caller(). + Str("hostname", node.Hostname). + Err(err). + Msg("Failed to normalize node hostname in DB migration") + } + + err = tx.Model(nodes[item]).Updates(types.Node{ + GivenName: normalizedHostname, + }).Error + if err != nil { + log.Error(). + Caller(). + Str("hostname", node.Hostname). + Err(err). + Msg("Failed to save normalized node name in DB migration") + } } } } - err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") + err = tx.AutoMigrate(&KV{}) if err != nil { - log.Error().Err(err).Msg("Error dropping enabled_routes column") - } - } - - if tx.Migrator().HasColumn(&types.Node{}, "given_name") { - nodes := types.Nodes{} - if err := tx.Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") + return err } - for item, node := range nodes { - if node.GivenName == "" { - normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( - node.Hostname, - ) - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). 
- Msg("Failed to normalize node hostname in DB migration") - } - - err = tx.Model(nodes[item]).Updates(types.Node{ - GivenName: normalizedHostname, - }).Error - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). - Msg("Failed to save normalized node name in DB migration") - } - } + err = tx.AutoMigrate(&types.PreAuthKey{}) + if err != nil { + return err } - } - err = tx.AutoMigrate(&KV{}) - if err != nil { - return err - } + err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) + if err != nil { + return err + } - err = tx.AutoMigrate(&types.PreAuthKey{}) - if err != nil { - return err - } + _ = tx.Migrator().DropTable("shared_machines") - err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) - if err != nil { - return err - } + err = tx.AutoMigrate(&types.APIKey{}) + if err != nil { + return err + } - _ = tx.Migrator().DropTable("shared_machines") - - err = tx.AutoMigrate(&types.APIKey{}) - if err != nil { - return err - } - - return nil + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, }, - Rollback: func(tx *gorm.DB) error { - return nil + { + // drop key-value table, it is not used, and has not contained + // useful data for a long time or ever. + ID: "202312101430", + Migrate: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("kvs") + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, }, }, - { - // drop key-value table, it is not used, and has not contained - // useful data for a long time or ever. 
- ID: "202312101430", - Migrate: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("kvs") - }, - Rollback: func(tx *gorm.DB) error { - return nil - }, - }, - }) + ) if err = migrations.Migrate(); err != nil { log.Fatal().Err(err).Msgf("Migration failed: %v", err) @@ -319,7 +334,6 @@ func NewHeadscaleDatabase( } func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { - // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface if cfg.Debug { @@ -374,10 +388,22 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass) } - return gorm.Open(postgres.Open(dbString), &gorm.Config{ + db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, }) + if err != nil { + return nil, err + } + + sqlDB, _ := db.DB() + sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections) + sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections) + sqlDB.SetConnMaxIdleTime( + time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second, + ) + + return db, nil } return nil, fmt.Errorf( diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index d83b21f..a82218e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -11,7 +11,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -20,6 +19,8 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + + "github.com/juanfont/headscale/hscontrol/util" ) const ( @@ -75,12 +76,15 @@ type SqliteConfig struct { } type PostgresConfig struct { - Host string - Port int - Name string - User string - Pass string - Ssl string + Host string + Port int + Name string + User string + Pass string + Ssl string + MaxOpenConnections int + MaxIdleConnections int + ConnMaxIdleTimeSecs int } type DatabaseConfig 
struct { @@ -213,6 +217,9 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("db_ssl", false) viper.SetDefault("database.postgres.ssl", false) + viper.SetDefault("database.postgres.max_open_conns", 10) + viper.SetDefault("database.postgres.max_idle_conns", 10) + viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.strip_email_domain", true) @@ -287,7 +294,7 @@ func LoadConfig(path string, isFile bool) error { } if errorText != "" { - //nolint + // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) } else { return nil @@ -429,22 +436,30 @@ func GetDatabaseConfig() DatabaseConfig { case "sqlite": type_ = "sqlite3" default: - log.Fatal().Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) + log.Fatal(). + Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) } return DatabaseConfig{ Type: type_, Debug: debug, Sqlite: SqliteConfig{ - Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")), + Path: util.AbsolutePathFromConfigPath( + viper.GetString("database.sqlite.path"), + ), }, Postgres: PostgresConfig{ - Host: viper.GetString("database.postgres.host"), - Port: viper.GetInt("database.postgres.port"), - Name: viper.GetString("database.postgres.name"), - User: viper.GetString("database.postgres.user"), - Pass: viper.GetString("database.postgres.pass"), - Ssl: viper.GetString("database.postgres.ssl"), + Host: viper.GetString("database.postgres.host"), + Port: viper.GetInt("database.postgres.port"), + Name: viper.GetString("database.postgres.name"), + User: viper.GetString("database.postgres.user"), + Pass: viper.GetString("database.postgres.pass"), + Ssl: viper.GetString("database.postgres.ssl"), + MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"), + MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"), + 
ConnMaxIdleTimeSecs: viper.GetInt( + "database.postgres.conn_max_idle_time_secs", + ), }, } } From c3257e2146304c52e588c6de2fd28bcc0f13b1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?l=C3=B6=C3=B6ps?= Date: Fri, 9 Feb 2024 13:16:17 -0500 Subject: [PATCH 13/13] docs(windows-client): add Windows registry command (#1658) Add Windows registry command to create the `Tailscale IPN` path before setting properties. --- docs/windows-client.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/windows-client.md b/docs/windows-client.md index fcb8c0e..38d330b 100644 --- a/docs/windows-client.md +++ b/docs/windows-client.md @@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor: Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"): ``` +New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN" New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL ```