diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 02e4742..8563e7a 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -50,3 +50,16 @@ instead of filing a bug report. ## To Reproduce + +## Logs and attachments + + diff --git a/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml b/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml new file mode 100644 index 0000000..def07cc --- /dev/null +++ b/.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestEnableDisableAutoApprovedRoute + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestEnableDisableAutoApprovedRoute: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestEnableDisableAutoApprovedRoute + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... 
\ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestEnableDisableAutoApprovedRoute$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml b/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml new file mode 100644 index 0000000..18fd341 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestPingAllByIPPublicDERP + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestPingAllByIPPublicDERP: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestPingAllByIPPublicDERP + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... 
\ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestPingAllByIPPublicDERP$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml b/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml new file mode 100644 index 0000000..3cb3f11 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestSubnetRouteACL + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestSubnetRouteACL: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestSubnetRouteACL + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... \ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestSubnetRouteACL$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/CHANGELOG.md b/CHANGELOG.md index 6484bf3..03d1563 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - - The latest supported client is 1.36 + - The latest supported client is 1.38 - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. 
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611) @@ -34,17 +34,21 @@ after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460) ### Changes -Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) -Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) -Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287) -SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) -State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) -Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) -Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) -Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) -Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) -Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) -Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594) +- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) +- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) +- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) +- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) +- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) +- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) +- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) +- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +- Added the possibility to manually create a DERP map entry which can be customized, instead of creating it automatically. [#1565](https://github.com/juanfont/headscale/pull/1565) +- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700) - Old structure is now considered deprecated and will be removed in the future. - Adds additional configuration for PostgreSQL for setting max open connections, idle connections, and idle connection lifetime.
+- Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287) +- Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594) ## 0.22.3 (2023-05-12) diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index dfaf512..3f3322e 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -6,25 +6,11 @@ import ( "github.com/efekarakus/termcolor" "github.com/juanfont/headscale/cmd/headscale/cli" - "github.com/pkg/profile" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) func main() { - if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { - if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { - err := os.MkdirAll(profilePath, os.ModePerm) - if err != nil { - log.Fatal().Err(err).Msg("failed to create profiling directory") - } - - defer profile.Start(profile.ProfilePath(profilePath)).Stop() - } else { - defer profile.Start().Stop() - } - } - var colors bool switch l := termcolor.SupportLevel(os.Stderr); l { case termcolor.Level16M: diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 897e253..d73d30b 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") + c.Assert(viper.GetString("database.type"), check.Equals, "sqlite") + c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") @@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") diff --git a/config-example.yaml b/config-example.yaml index 5325840..65a5c9f 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -94,6 +94,16 @@ derp: # private_key_path: /var/lib/headscale/derp_server_private.key + # This flag controls whether the DERP map entry for the embedded DERP server is written automatically. + # Setting it to false lets you create your own customized DERP map entry for the embedded server in a locally available file referenced via the DERP.paths parameter. + # If you enable the DERP server and set this to false, you must add the DERP server to the DERP map using DERP.paths. + automatically_add_embedded_derp_region: true + + # For
better connection stability (especially when using an Exit-Node and DNS is not working), + # it is possible to optionally add the public IPv4 and IPv6 addresses to the DERP map using: + ipv4: 1.2.3.4 + ipv6: 2001:db8::1 + # List of externally available DERP maps encoded in JSON urls: - https://controlplane.tailscale.com/derpmap/default @@ -128,24 +138,28 @@ ephemeral_node_inactivity_timeout: 30m # In case of doubts, do not touch the default 10s. node_update_check_interval: 10s -# SQLite config -db_type: sqlite3 +database: + type: sqlite -# For production: -db_path: /var/lib/headscale/db.sqlite + # SQLite config + sqlite: + path: /var/lib/headscale/db.sqlite -# # Postgres config -# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. -# db_type: postgres -# db_host: localhost -# db_port: 5432 -# db_name: headscale -# db_user: foo -# db_pass: bar + # # Postgres config + # postgres: + # # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. + # host: localhost + # port: 5432 + # name: headscale + # user: foo + # pass: bar + # max_open_conns: 10 + # max_idle_conns: 10 + # conn_max_idle_time_secs: 3600 -# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need -# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. -# db_ssl: false + # # If an 'sslmode' other than 'require' (true) or 'disable' (false) is required, set the 'sslmode' you need + # # in the 'ssl' field. Refer to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. + # ssl: false ### TLS configuration # diff --git a/docs/windows-client.md b/docs/windows-client.md index fcb8c0e..38d330b 100644 --- a/docs/windows-client.md +++ b/docs/windows-client.md @@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor: Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"): ``` +New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN" New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL ``` diff --git a/hscontrol/app.go b/hscontrol/app.go index 5327d6f..78b72bf 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -12,7 +12,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -33,6 +32,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" + "github.com/pkg/profile" "github.com/prometheus/client_golang/prometheus/promhttp" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -48,6 +48,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -61,7 +62,7 @@ var ( "unknown value for Lets Encrypt challenge type", ) errEmptyInitialDERPMap = errors.New( - "initial DERPMap is empty, Headscale requries at least one entry", + "initial DERPMap is empty, Headscale requires at least one entry", ) ) @@ -116,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: -
dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -154,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, @@ -163,9 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, + cfg.Database, app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) @@ -234,8 +200,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { // seen for longer than h.cfg.EphemeralNodeInactivityTimeout. func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) + + var update types.StateUpdate + var changed bool for range ticker.C { - h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) + + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring ephemeral nodes") + continue + } + + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -246,9 +227,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) { ticker := time.NewTicker(interval) lastCheck := time.Unix(0, 0) + var update types.StateUpdate + var changed bool for range ticker.C { - lastCheck = h.db.ExpireExpiredNodes(lastCheck) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) + + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring nodes") + continue + } + + log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes") + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -268,7 +264,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { case <-ticker.C: log.Info().Msg("Fetching DERPMap updates") h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - if h.cfg.DERP.ServerEnabled { + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() h.DERPMap.Regions[region.RegionID] = ®ion } @@ -278,7 +274,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { DERPMap: h.DERPMap, } if stateUpdate.Valid() { - h.nodeNotifier.NotifyAll(stateUpdate) + ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) } } } @@ -485,6 +482,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches a GIN 
server with the Headscale API. func (h *Headscale) Serve() error { + if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { + if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { + err := os.MkdirAll(profilePath, os.ModePerm) + if err != nil { + log.Fatal().Err(err).Msg("failed to create profiling directory") + } + + defer profile.Start(profile.ProfilePath(profilePath)).Stop() + } else { + defer profile.Start().Stop() + } + } + var err error // Fetch an initial DERP Map before we start serving @@ -501,7 +511,9 @@ func (h *Headscale) Serve() error { return err } - h.DERPMap.Regions[region.RegionID] = ®ion + if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { + h.DERPMap.Regions[region.RegionID] = ®ion + } go h.DERPServer.ServeSTUN() } @@ -708,14 +720,16 @@ func (h *Headscale) Serve() error { var tailsqlContext context.Context if tailsqlEnabled { - if h.cfg.DBtype != db.Sqlite { - log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite) + if h.cfg.Database.Type != types.DatabaseSqlite { + log.Fatal(). + Str("type", h.cfg.Database.Type). + Msgf("tailsql only support %q", types.DatabaseSqlite) } if tailsqlTSKey == "" { log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") } tailsqlContext = context.Background() - go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath) + go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) } // Handle common process-killing signals so we can gracefully shut down: @@ -751,7 +765,8 @@ func (h *Headscale) Serve() error { Str("path", aclPath). Msg("ACL policy successfully reloaded, notifying nodes of change") - h.nodeNotifier.NotifyAll(types.StateUpdate{ + ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, }) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 4fe5a16..3e9557a 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -1,6 +1,7 @@ package hscontrol import ( + "context" "encoding/json" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -199,6 +201,19 @@ func (h *Headscale) handleRegister( return } + // When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed + if node.NodeKey.String() != registerRequest.NodeKey.String() && + registerRequest.OldNodeKey.IsZero() && !node.IsExpired() { + h.handleNodeKeyRefresh( + writer, + registerRequest, + *node, + machineKey, + ) + + return + } + if registerRequest.Followup != "" { select { case <-req.Context().Done(): @@ -230,8 +245,6 @@ func (h *Headscale) handleRegister( // handleAuthKey contains the logic to manage auth key client registration // When using Noise, the machineKey is Zero. -// -// TODO: check if any locks are needed around IP allocation. 
func (h *Headscale) handleAuthKey( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, @@ -298,6 +311,9 @@ func (h *Headscale) handleAuthKey( nodeKey := registerRequest.NodeKey + var update types.StateUpdate + var mkey key.MachinePublic + // retrieve node information if it exist // The error is not important, because if it does not // exist, then this is a new node and we will move @@ -311,7 +327,7 @@ func (h *Headscale) handleAuthKey( node.NodeKey = nodeKey node.AuthKeyID = uint(pak.ID) - err := h.db.NodeSetExpiry(node, registerRequest.Expiry) + err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -322,10 +338,13 @@ func (h *Headscale) handleAuthKey( return } + mkey = node.MachineKey + update = types.StateUpdateExpire(node.ID, registerRequest.Expiry) + aclTags := pak.Proto().GetAclTags() if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(node, aclTags) + err = h.db.SetTags(node.ID, aclTags) if err != nil { log.Error(). @@ -357,6 +376,7 @@ func (h *Headscale) handleAuthKey( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, + User: pak.User, MachineKey: machineKey, RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, @@ -380,9 +400,18 @@ func (h *Headscale) handleAuthKey( return } + + mkey = node.MachineKey + update = types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from auth.handleAuthKey", + } } - err = h.db.UsePreAuthKey(pak) + err = h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.UsePreAuthKey(tx, pak) + }) if err != nil { log.Error(). Caller(). @@ -424,6 +453,13 @@ func (h *Headscale) handleAuthKey( Caller(). Err(err). Msg("Failed to write response") + return + } + + // TODO(kradalby): if notifying after register make sense. + if update.Valid() { + ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String()) } log.Info(). @@ -489,7 +525,7 @@ func (h *Headscale) handleNodeLogOut( Msg("Client requested logout") now := time.Now() - err := h.db.NodeSetExpiry(&node, now) + err := h.db.NodeSetExpiry(node.ID, now) if err != nil { log.Error(). Caller(). @@ -500,17 +536,10 @@ func (h *Headscale) handleNodeLogOut( return } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &now, - }, - }, - } + stateUpdate := types.StateUpdateExpire(node.ID, now) if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } resp.AuthURL = "" @@ -541,7 +570,7 @@ func (h *Headscale) handleNodeLogOut( } if node.IsEphemeral() { - err = h.db.DeleteNode(&node) + err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) if err != nil { log.Error(). Err(err). 
@@ -549,6 +578,15 @@ func (h *Headscale) handleNodeLogOut( Msg("Cannot delete ephemeral node from the database") } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return } @@ -620,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh( Str("node", node.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey) + err := h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) + }) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/db/addresses.go b/hscontrol/db/addresses.go index beccf84..5857870 100644 --- a/hscontrol/db/addresses.go +++ b/hscontrol/db/addresses.go @@ -13,16 +13,23 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" + "gorm.io/gorm" ) var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) { + return getAvailableIPs(rx, hsdb.ipPrefixes) + }) +} + +func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) { var ips types.NodeAddresses var err error - for _, ipPrefix := range hsdb.ipPrefixes { + for _, ipPrefix := range ipPrefixes { var ip *netip.Addr - ip, err = hsdb.getAvailableIP(ipPrefix) + ip, err = getAvailableIP(rx, ipPrefix) if err != nil { return ips, err } @@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { return ips, err } -func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := hsdb.getUsedIPs() +func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) { + usedIps, err := getUsedIPs(rx) if err != nil { return nil, err } @@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro } } -func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { +func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. 
var addressesSlices []string - hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) + rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) var ips netipx.IPSetBuilder for _, slice := range addressesSlices { diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 07059ea..ef33659 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -7,10 +7,16 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := db.getAvailableIPs() + tx := db.DB.Begin() + defer tx.Rollback() + + ips, err := getAvailableIPs(tx, []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }) c.Assert(err, check.IsNil) @@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - usedIps, err := db.getUsedIPs() + db.Write(func(tx *gorm.DB) error { + return tx.Save(&node).Error + }) + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) { } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := db.CreateUser("test-ip-multi") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) + ipPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + for index := 1; index <= 350; index++ { - db.ipAllocationMutex.Lock() + tx := db.DB.Begin() - ips, err := db.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = getNode(tx, "test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - db.ipAllocationMutex.Unlock() + tx.Save(&node) + c.Assert(tx.Commit().Error, check.IsNil) } - usedIps, err := db.getUsedIPs() + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) ips2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index bc8dc2b..5108314 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb 
*HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey( Expiration: expiration, } - if err := hsdb.db.Save(&key).Error; err != nil { + if err := hsdb.DB.Save(&key).Error; err != nil { return "", nil, fmt.Errorf("failed to save API key to database: %w", err) } @@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - keys := []types.APIKey{} - if err := hsdb.db.Find(&keys).Error; err != nil { + if err := hsdb.DB.Find(&keys).Error; err != nil { return nil, err } @@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { + if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. 
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 030a6f0..4ded07f 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,24 +6,20 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" - "sync" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" -) -const ( - Postgres = "postgres" - Sqlite = "sqlite3" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) var errDatabaseNotSupported = errors.New("database type not supported") @@ -36,12 +32,7 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifier *notifier.Notifier - - mu sync.RWMutex - - ipAllocationMutex sync.Mutex + DB *gorm.DB ipPrefixes []netip.Prefix baseDomain string @@ -50,275 +41,290 @@ type HSDatabase struct { // TODO(kradalby): assemble this struct from toptions or something typed // rather than arguments. func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, + cfg types.DatabaseConfig, notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) + dbConn, err := openDB(cfg) if err != nil { return nil, err } - migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{ - // New migrations should be added as transactions at the end of this list. - // The initial commit here is quite messy, completely out of order and - // has no versioning and is the tech debt of not having versioned migrations - // prior to this point. This first migration is all DB changes to bring a DB - // up to 0.23.0. - { - ID: "202312101416", - Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { - tx.Exec(`create extension if not exists "uuid-ossp";`) - } - - _ = tx.Migrator().RenameTable("namespaces", "users") - - // the big rename from Machine to Node - _ = tx.Migrator().RenameTable("machines", "nodes") - _ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id") - - err = tx.AutoMigrate(types.User{}) - if err != nil { - return err - } - - _ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id") - _ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - - _ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses") - _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") - - // GivenName is used as the primary source of DNS names, make sure - // the field is populated and normalized if it was not when the - // node was registered. 
- _ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name") - - // If the Node table has a column for registered, - // find all occourences of "false" and drop them. Then - // remove the column. - if tx.Migrator().HasColumn(&types.Node{}, "registered") { - log.Info(). - Msg(`Database has legacy "registered" column in node, removing...`) - - nodes := types.Nodes{} - if err := tx.Not("registered").Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") + migrations := gormigrate.New( + dbConn, + gormigrate.DefaultOptions, + []*gormigrate.Migration{ + // New migrations should be added as transactions at the end of this list. + // The initial commit here is quite messy, completely out of order and + // has no versioning and is the tech debt of not having versioned migrations + // prior to this point. This first migration is all DB changes to bring a DB + // up to 0.23.0. + { + ID: "202312101416", + Migrate: func(tx *gorm.DB) error { + if cfg.Type == types.DatabasePostgres { + tx.Exec(`create extension if not exists "uuid-ossp";`) } - for _, node := range nodes { - log.Info(). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Deleting unregistered node") - if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Error deleting unregistered node") - } - } + _ = tx.Migrator().RenameTable("namespaces", "users") - err := tx.Migrator().DropColumn(&types.Node{}, "registered") - if err != nil { - log.Error().Err(err).Msg("Error dropping registered column") - } - } + // the big rename from Machine to Node + _ = tx.Migrator().RenameTable("machines", "nodes") + _ = tx.Migrator(). + RenameColumn(&types.Route{}, "machine_id", "node_id") - err = tx.AutoMigrate(&types.Route{}) - if err != nil { - return err - } - - err = tx.AutoMigrate(&types.Node{}) - if err != nil { - return err - } - - // Ensure all keys have correct prefixes - // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 - type result struct { - ID uint64 - MachineKey string - NodeKey string - DiscoKey string - } - var results []result - err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error - if err != nil { - return err - } - - for _, node := range results { - mKey := node.MachineKey - if !strings.HasPrefix(node.MachineKey, "mkey:") { - mKey = "mkey:" + node.MachineKey - } - nKey := node.NodeKey - if !strings.HasPrefix(node.NodeKey, "nodekey:") { - nKey = "nodekey:" + node.NodeKey - } - - dKey := node.DiscoKey - if !strings.HasPrefix(node.DiscoKey, "discokey:") { - dKey = "discokey:" + node.DiscoKey - } - - err := tx.Exec( - "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", - sql.Named("mKey", mKey), - sql.Named("nKey", nKey), - sql.Named("dKey", dKey), - sql.Named("id", node.ID), - ).Error + err = tx.AutoMigrate(types.User{}) if err != nil { return err } - } - if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { - log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...") + _ = tx.Migrator(). + RenameColumn(&types.Node{}, "namespace_id", "user_id") + _ = tx.Migrator(). + RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - type NodeAux struct { - ID uint64 - EnabledRoutes types.IPPrefixes - } + _ = tx.Migrator(). 
+ RenameColumn(&types.Node{}, "ip_address", "ip_addresses") + _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") - nodesAux := []NodeAux{} - err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error - if err != nil { - log.Fatal().Err(err).Msg("Error accessing db") - } - for _, node := range nodesAux { - for _, prefix := range node.EnabledRoutes { - if err != nil { + // GivenName is used as the primary source of DNS names, make sure + // the field is populated and normalized if it was not when the + // node was registered. + _ = tx.Migrator(). + RenameColumn(&types.Node{}, "nickname", "given_name") + + // If the Node table has a column for registered, + // find all occourences of "false" and drop them. Then + // remove the column. + if tx.Migrator().HasColumn(&types.Node{}, "registered") { + log.Info(). + Msg(`Database has legacy "registered" column in node, removing...`) + + nodes := types.Nodes{} + if err := tx.Not("registered").Find(&nodes).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for _, node := range nodes { + log.Info(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Msg("Deleting unregistered node") + if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil { log.Error(). Err(err). - Str("enabled_route", prefix.String()). - Msg("Error parsing enabled_route") - - continue + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Msg("Error deleting unregistered node") } + } - err = tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). - First(&types.Route{}). - Error - if err == nil { - log.Info(). - Str("enabled_route", prefix.String()). - Msg("Route already migrated to new table, skipping") + err := tx.Migrator().DropColumn(&types.Node{}, "registered") + if err != nil { + log.Error().Err(err).Msg("Error dropping registered column") + } + } - continue + err = tx.AutoMigrate(&types.Route{}) + if err != nil { + return err + } + + err = tx.AutoMigrate(&types.Node{}) + if err != nil { + return err + } + + // Ensure all keys have correct prefixes + // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 + type result struct { + ID uint64 + MachineKey string + NodeKey string + DiscoKey string + } + var results []result + err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes"). + Find(&results). + Error + if err != nil { + return err + } + + for _, node := range results { + mKey := node.MachineKey + if !strings.HasPrefix(node.MachineKey, "mkey:") { + mKey = "mkey:" + node.MachineKey + } + nKey := node.NodeKey + if !strings.HasPrefix(node.NodeKey, "nodekey:") { + nKey = "nodekey:" + node.NodeKey + } + + dKey := node.DiscoKey + if !strings.HasPrefix(node.DiscoKey, "discokey:") { + dKey = "discokey:" + node.DiscoKey + } + + err := tx.Exec( + "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", + sql.Named("mKey", mKey), + sql.Named("nKey", nKey), + sql.Named("dKey", dKey), + sql.Named("id", node.ID), + ).Error + if err != nil { + return err + } + } + + if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { + log.Info(). + Msgf("Database has legacy enabled_routes column in node, migrating...") + + type NodeAux struct { + ID uint64 + EnabledRoutes types.IPPrefixes + } + + nodesAux := []NodeAux{} + err := tx.Table("nodes"). + Select("id, enabled_routes"). + Scan(&nodesAux). 
+ Error + if err != nil { + log.Fatal().Err(err).Msg("Error accessing db") + } + for _, node := range nodesAux { + for _, prefix := range node.EnabledRoutes { + if err != nil { + log.Error(). + Err(err). + Str("enabled_route", prefix.String()). + Msg("Error parsing enabled_route") + + continue + } + + err = tx.Preload("Node"). + Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). + First(&types.Route{}). + Error + if err == nil { + log.Info(). + Str("enabled_route", prefix.String()). + Msg("Route already migrated to new table, skipping") + + continue + } + + route := types.Route{ + NodeID: node.ID, + Advertised: true, + Enabled: true, + Prefix: types.IPPrefix(prefix), + } + if err := tx.Create(&route).Error; err != nil { + log.Error().Err(err).Msg("Error creating route") + } else { + log.Info(). + Uint64("node_id", route.NodeID). + Str("prefix", prefix.String()). + Msg("Route migrated") + } } + } - route := types.Route{ - NodeID: node.ID, - Advertised: true, - Enabled: true, - Prefix: types.IPPrefix(prefix), - } - if err := tx.Create(&route).Error; err != nil { - log.Error().Err(err).Msg("Error creating route") - } else { - log.Info(). - Uint64("node_id", route.NodeID). - Str("prefix", prefix.String()). - Msg("Route migrated") + err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") + if err != nil { + log.Error(). + Err(err). + Msg("Error dropping enabled_routes column") + } + } + + if tx.Migrator().HasColumn(&types.Node{}, "given_name") { + nodes := types.Nodes{} + if err := tx.Find(&nodes).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for item, node := range nodes { + if node.GivenName == "" { + normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( + node.Hostname, + ) + if err != nil { + log.Error(). + Caller(). + Str("hostname", node.Hostname). + Err(err). + Msg("Failed to normalize node hostname in DB migration") + } + + err = tx.Model(nodes[item]).Updates(types.Node{ + GivenName: normalizedHostname, + }).Error + if err != nil { + log.Error(). + Caller(). + Str("hostname", node.Hostname). + Err(err). + Msg("Failed to save normalized node name in DB migration") + } } } } - err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") + err = tx.AutoMigrate(&KV{}) if err != nil { - log.Error().Err(err).Msg("Error dropping enabled_routes column") - } - } - - if tx.Migrator().HasColumn(&types.Node{}, "given_name") { - nodes := types.Nodes{} - if err := tx.Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") + return err } - for item, node := range nodes { - if node.GivenName == "" { - normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( - node.Hostname, - ) - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). - Msg("Failed to normalize node hostname in DB migration") - } - - err = tx.Model(nodes[item]).Updates(types.Node{ - GivenName: normalizedHostname, - }).Error - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). 
- Msg("Failed to save normalized node name in DB migration") - } - } + err = tx.AutoMigrate(&types.PreAuthKey{}) + if err != nil { + return err } - } - err = tx.AutoMigrate(&KV{}) - if err != nil { - return err - } + err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) + if err != nil { + return err + } - err = tx.AutoMigrate(&types.PreAuthKey{}) - if err != nil { - return err - } + _ = tx.Migrator().DropTable("shared_machines") - err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) - if err != nil { - return err - } + err = tx.AutoMigrate(&types.APIKey{}) + if err != nil { + return err + } - _ = tx.Migrator().DropTable("shared_machines") - - err = tx.AutoMigrate(&types.APIKey{}) - if err != nil { - return err - } - - return nil + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, }, - Rollback: func(tx *gorm.DB) error { - return nil + { + // drop key-value table, it is not used, and has not contained + // useful data for a long time or ever. + ID: "202312101430", + Migrate: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("kvs") + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, }, }, - { - // drop key-value table, it is not used, and has not contained - // useful data for a long time or ever. - ID: "202312101430", - Migrate: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("kvs") - }, - Rollback: func(tx *gorm.DB) error { - return nil - }, - }, - }) + ) if err = migrations.Migrate(); err != nil { log.Fatal().Err(err).Msgf("Migration failed: %v", err) } db := HSDatabase{ - db: dbConn, - notifier: notifier, + DB: dbConn, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -327,20 +333,19 @@ func NewHeadscaleDatabase( return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") - +func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { + // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface - if debug { + if cfg.Debug { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { - case Sqlite: + switch cfg.Type { + case types.DatabaseSqlite: db, err := gorm.Open( - sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), + sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, @@ -359,16 +364,51 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return db, err - case Postgres: - return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ + case types.DatabasePostgres: + dbString := fmt.Sprintf( + "host=%s dbname=%s user=%s", + cfg.Postgres.Host, + cfg.Postgres.Name, + cfg.Postgres.User, + ) + + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if !sslEnabled { + dbString += " sslmode=disable" + } + } else { + dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl) + } + + if cfg.Postgres.Port != 0 { + dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port) + } + + if cfg.Postgres.Pass != "" { + dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass) + } + + db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, }) + if err != nil { + return nil, err + } + + sqlDB, _ := db.DB() + sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections) + sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections) + sqlDB.SetConnMaxIdleTime( + 
time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second, + ) + + return db, nil } return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + cfg.Type, errDatabaseNotSupported, ) } @@ -376,7 +416,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - sqlDB, err := hsdb.db.DB() + sqlDB, err := hsdb.DB.DB() if err != nil { return err } @@ -385,10 +425,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error { } func (hsdb *HSDatabase) Close() error { - db, err := hsdb.db.DB() + db, err := hsdb.DB.DB() if err != nil { return err } return db.Close() } + +func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error { + rx := hsdb.DB.Begin() + defer rx.Rollback() + return fn(rx) +} + +func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { + rx := db.Begin() + defer rx.Rollback() + ret, err := fn(rx) + if err != nil { + var no T + return no, err + } + return ret, nil +} + +func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { + tx := hsdb.DB.Begin() + defer tx.Rollback() + if err := fn(tx); err != nil { + return err + } + + return tx.Commit().Error +} + +func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) { + tx := db.Begin() + defer tx.Rollback() + ret, err := fn(tx) + if err != nil { + var no T + return no, err + } + return ret, tx.Commit().Error +} diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ce535b9..a747429 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -34,22 +34,21 @@ var ( ) ) -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPeers(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListPeers(rx, node) + }) } -func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { +// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. +func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) { log.Trace(). Caller(). Str("node", node.Hostname). Msg("Finding direct peers") nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodes() +func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) } -func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { - nodes := []types.Node{} - if err := hsdb.db. +func ListNodes(tx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). 
@@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByGivenName(givenName) -} - -func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) { +func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err return nodes, nil } -// GetNode finds a Node by name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return getNode(rx, user, name) + }) +} - nodes, err := hsdb.ListNodesByUser(user) +// getNode finds a Node by name and user and returns the Node struct. +func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, user) if err != nil { return nil, err } @@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { return nil, ErrNodeNotFound } -// GetNodeByGivenName finds a Node by given name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByGivenName( - user string, - givenName string, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if err := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).First(&node).Error; err != nil { - return nil, err - } - - return nil, ErrNodeNotFound +func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByID(rx, id) + }) } // GetNodeByID finds a Node by ID and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { return &mach, nil } -// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByMachineKey( - machineKey key.MachinePublic, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeByMachineKey(machineKey) +func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByMachineKey(rx, machineKey) + }) } -func (hsdb *HSDatabase) getNodeByMachineKey( +// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. +func GetNodeByMachineKey( + tx *gorm.DB, machineKey key.MachinePublic, ) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). 
@@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey( return &mach, nil } -// GetNodeByNodeKey finds a Node by its current NodeKey. -func (hsdb *HSDatabase) GetNodeByNodeKey( +func (hsdb *HSDatabase) GetNodeByAnyKey( + machineKey key.MachinePublic, nodeKey key.NodePublic, + oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if result := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "node_key = ?", - nodeKey.String()); result.Error != nil { - return nil, result.Error - } - - return &node, nil + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) + }) } // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByAnyKey( +// TODO(kradalby): see if we can remove this. +func GetNodeByAnyKey( + tx *gorm.DB, machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - node := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey( return &node, nil } -func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - if result := hsdb.db.Find(node).First(&node); result.Error != nil { - return result.Error - } - - return nil +func (hsdb *HSDatabase) SetTags( + nodeID uint64, + tags []string, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return SetTags(tx, nodeID, tags) + }) } // SetTags takes a Node struct pointer and update the forced tags. -func (hsdb *HSDatabase) SetTags( - node *types.Node, +func SetTags( + tx *gorm.DB, + nodeID uint64, tags []string, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - if len(tags) == 0 { return nil } - newTags := []string{} + newTags := types.StringList{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } - if err := hsdb.db.Model(node).Updates(types.Node{ - ForcedTags: newTags, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { return fmt.Errorf("failed to update tags for node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.SetTags", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } // RenameNode takes a Node struct and a new GivenName for the nodes // and renames it. -func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameNode(tx *gorm.DB, + nodeID uint64, newName string, +) error { err := util.CheckForFQDNRules( newName, ) @@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { log.Error(). Caller(). Str("func", "RenameNode"). - Str("node", node.Hostname). + Uint64("nodeID", nodeID). Str("newName", newName). Err(err). 
Msg("failed to rename node") return err } - node.GivenName = newName - if err := hsdb.db.Model(node).Updates(types.Node{ - GivenName: newName, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.RenameNode", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } +func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetExpiry(tx, nodeID, expiry) + }) +} + // NodeSetExpiry takes a Node struct and a new expiry time. -func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.nodeSetExpiry(node, expiry) +func NodeSetExpiry(tx *gorm.DB, + nodeID uint64, expiry time.Time, +) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } -func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error { - if err := hsdb.db.Model(node).Updates(types.Node{ - Expiry: &expiry, - }).Error; err != nil { - return fmt.Errorf( - "failed to refresh node (update expiration) in the database: %w", - err, - ) - } - - node.Expiry = &expiry - - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &expiry, - }, - }, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - - return nil +func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DeleteNode(tx, node, isConnected) + }) } // DeleteNode deletes a Node from the database. -func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.deleteNode(node) -} - -func (hsdb *HSDatabase) deleteNode(node *types.Node) error { - err := hsdb.deleteNodeRoutes(node) +// Caller is responsible for notifying all of change. +func DeleteNode(tx *gorm.DB, + node *types.Node, + isConnected map[key.MachinePublic]bool, +) error { + err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{}) if err != nil { return err } // Unscoped causes the node to be fully removed from the database. - if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil { + if err := tx.Unscoped().Delete(&node).Error; err != nil { return err } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - return nil } // UpdateLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. -// This is mostly used to indicate if a node is online and is not -// extremely important to make sure is fully correct and to avoid -// holding up the hot path, does not contain any locks and isnt -// concurrency safe. But that should be ok. 
-func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error { - return hsdb.db.Model(node).Updates(types.Node{ - LastSeen: node.LastSeen, - }).Error +func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } -func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( +func RegisterNodeFromAuthCallback( + tx *gorm.DB, cache *cache.Cache, mkey key.MachinePublic, userName string, nodeExpiry *time.Time, registrationMethod string, + ipPrefixes []netip.Prefix, ) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - log.Debug(). Str("machine_key", mkey.ShortString()). Str("userName", userName). @@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( if nodeInterface, ok := cache.Get(mkey.String()); ok { if registrationNode, ok := nodeInterface.(types.Node); ok { - user, err := hsdb.getUser(userName) + user, err := GetUser(tx, userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register node from auth callback, %w", @@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( } registrationNode.UserID = user.ID + registrationNode.User = *user registrationNode.RegisterMethod = registrationMethod if nodeExpiry != nil { registrationNode.Expiry = nodeExpiry } - node, err := hsdb.registerNode( + node, err := RegisterNode( + tx, registrationNode, + ipPrefixes, ) if err == nil { @@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( return nil, ErrNodeNotFoundRegistrationCache } -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.registerNode(node) + return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { + return RegisterNode(tx, node, hsdb.ipPrefixes) + }) } -func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { +// RegisterNode is executed from the CLI to register a new Node using its MachineKey. +func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) { log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). @@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache if len(node.IPAddresses) > 0 { - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register existing node in the database: %w", err) } @@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { return &node, nil } - hsdb.ipAllocationMutex.Lock() - defer hsdb.ipAllocationMutex.Unlock() - - ips, err := hsdb.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) if err != nil { log.Error(). Caller(). @@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { node.IPAddresses = ips - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } @@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { } // NodeSetNodeKey sets the node key of a node and saves it to the database. 
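The new `RegisterNodeFromAuthCallback` above receives its transaction, registration cache and IP prefixes explicitly instead of reaching into `hsdb`. A sketch of the two-phase flow under stated assumptions: the cache is the `github.com/patrickmn/go-cache` store that headscale keeps pending registrations in (values stored as `types.Node`, matching the type assertion above), and `completeAuthCallback`, the `"alice"` user and the `"oidc"` method string are illustrative:

```go
package main

import (
	"net/netip"
	"time"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/patrickmn/go-cache"
	"gorm.io/gorm"
	"tailscale.com/types/key"
)

// completeAuthCallback sketches the flow: an earlier handler parked the
// half-registered node in the cache under its MachineKey, and the auth
// callback finishes the registration inside one write transaction.
func completeAuthCallback(
	hsdb *db.HSDatabase,
	registrationCache *cache.Cache,
	mkey key.MachinePublic,
	pending types.Node,
	ipPrefixes []netip.Prefix,
) (*types.Node, error) {
	// Normally done by the login handler; shown here for completeness.
	registrationCache.Set(mkey.String(), pending, 15*time.Minute)

	return db.Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
		return db.RegisterNodeFromAuthCallback(
			tx,
			registrationCache,
			mkey,
			"alice", // user the callback authenticated as
			nil,     // nil expiry keeps whatever the cached node carries
			"oidc",  // registrationMethod is a plain string
			ipPrefixes,
		)
	})
}
```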
-func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(node).Updates(types.Node{ +func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { + return tx.Model(node).Updates(types.Node{ NodeKey: nodeKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } -// NodeSetMachineKey sets the node key of a node and saves it to the database. func (hsdb *HSDatabase) NodeSetMachineKey( node *types.Node, machineKey key.MachinePublic, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetMachineKey(tx, node, machineKey) + }) +} - if err := hsdb.db.Model(node).Updates(types.Node{ +// NodeSetMachineKey sets the node key of a node and saves it to the database. +func NodeSetMachineKey( + tx *gorm.DB, + node *types.Node, + machineKey key.MachinePublic, +) error { + return tx.Model(node).Updates(types.Node{ MachineKey: machineKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } // NodeSave saves a node object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. -func (hsdb *HSDatabase) NodeSave(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +// TODO(kradalby): Remove this func, just use Save. +func NodeSave(tx *gorm.DB, node *types.Node) error { + return tx.Save(node).Error +} - if err := hsdb.db.Save(node).Error; err != nil { - return err - } - - return nil +func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetAdvertisedRoutes(rx, node) + }) } // GetAdvertisedRoutes returns the routes that are be advertised by the given node. -func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { +func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e return prefixes, nil } -// GetEnabledRoutes returns the routes that are enabled for the node. func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getEnabledRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetEnabledRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { +// GetEnabledRoutes returns the routes that are enabled for the node. +func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). 
Find(&routes).Error @@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro return prefixes, nil } -func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := hsdb.getEnabledRoutes(node) + enabledRoutes, err := GetEnabledRoutes(tx, node) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool return false } +func (hsdb *HSDatabase) enableRoutes( + node *types.Node, + routeStrs ...string, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return enableRoutes(tx, node, routeStrs...) + }) +} + // enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { +func enableRoutes(tx *gorm.DB, + node *types.Node, routeStrs ...string, +) (*types.StateUpdate, error) { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) if err != nil { - return err + return nil, err } newRoutes[index] = route } - advertisedRoutes, err := hsdb.getAdvertisedRoutes(node) + advertisedRoutes, err := GetAdvertisedRoutes(tx, node) if err != nil { - return err + return nil, err } for _, newRoute := range newRoutes { if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { - return fmt.Errorf( + return nil, fmt.Errorf( "route (%s) is not available on node %s: %w", node.Hostname, newRoute, ErrNodeRouteIsNotAvailable, @@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { route := types.Route{} - err := hsdb.db.Preload("Node"). + err := tx.Preload("Node"). Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). First(&route).Error if err == nil { @@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) if !route.IsExitRoute() { - route.IsPrimary = hsdb.isUniquePrefix(route) + route.IsPrimary = isUniquePrefix(tx, route) } - err = hsdb.db.Save(&route).Error + err = tx.Save(&route).Error if err != nil { - return fmt.Errorf("failed to enable route: %w", err) + return nil, fmt.Errorf("failed to enable route: %w", err) } } else { - return fmt.Errorf("failed to find route: %w", err) + return nil, fmt.Errorf("failed to find route: %w", err) } } // Ensure the node has the latest routes when notifying the other // nodes - nRoutes, err := hsdb.getNodeRoutes(node) + nRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return fmt.Errorf("failed to read back routes: %w", err) + return nil, fmt.Errorf("failed to read back routes: %w", err) } node.Routes = nRoutes @@ -729,17 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro Strs("routes", routeStrs). 
Msg("enabling routes") - stateUpdate := types.StateUpdate{ + return &types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, - Message: "called from db.enableRoutes", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore( - stateUpdate, node.MachineKey.String()) - } - - return nil + Message: "created in db.enableRoutes", + }, nil } func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { @@ -772,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName( mkey key.MachinePublic, suppliedName string, ) (string, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() + return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { + return GenerateGivenName(rx, mkey, suppliedName) + }) +} +func GenerateGivenName( + tx *gorm.DB, + mkey key.MachinePublic, + suppliedName string, +) (string, error) { givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := hsdb.listNodesByGivenName(givenName) + nodes, err := listNodesByGivenName(tx, givenName) if err != nil { return "", err } @@ -805,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName( return givenName, nil } -func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - users, err := hsdb.listUsers() +func ExpireEphemeralNodes(tx *gorm.DB, + inactivityThreshhold time.Duration, +) (types.StateUpdate, bool) { + users, err := ListUsers(tx) if err != nil { log.Error().Err(err).Msg("Error listing users") - return + return types.StateUpdate{}, false } + expired := make([]tailcfg.NodeID, 0) for _, user := range users { - nodes, err := hsdb.listNodesByUser(user.Name) + nodes, err := ListNodesByUser(tx, user.Name) if err != nil { log.Error(). Err(err). Str("user", user.Name). Msg("Error listing nodes in user") - return + return types.StateUpdate{}, false } - expired := make([]tailcfg.NodeID, 0) for idx, node := range nodes { if node.IsEphemeral() && node.LastSeen != nil && time.Now(). @@ -838,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) Str("node", node.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.deleteNode(nodes[idx]) + // empty isConnected map as ephemeral nodes are not routes + err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{}) if err != nil { log.Error(). Err(err). @@ -848,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) } } - if len(expired) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: expired, - }) - } + // TODO(kradalby): needs to be moved out of transaction } + if len(expired) > 0 { + return types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }, true + } + + return types.StateUpdate{}, false } -func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func ExpireExpiredNodes(tx *gorm.DB, + lastCheck time.Time, +) (time.Time, types.StateUpdate, bool) { // use the time of the start of the function to ensure we // dont miss some nodes by returning it _after_ we have // checked everything. started := time.Now() - expiredNodes := make([]*types.Node, 0) + expired := make([]*tailcfg.PeerChange, 0) - nodes, err := hsdb.listNodes() + nodes, err := ListNodes(tx) if err != nil { log.Error(). Err(err). 
Msg("Error listing nodes to find expired nodes") - return time.Unix(0, 0) + return time.Unix(0, 0), types.StateUpdate{}, false } for index, node := range nodes { if node.IsExpired() && @@ -882,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // It will notify about all nodes that has been expired. // It should only notify about expired nodes since _last check_. node.Expiry.After(lastCheck) { - expiredNodes = append(expiredNodes, &nodes[index]) + expired = append(expired, &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: node.Expiry, + }) + now := time.Now() // Do not use setNodeExpiry as that has a notifier hook, which // can cause a deadlock, we are updating all changed nodes later // and there is no point in notifiying twice. - if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ - Expiry: &started, + if err := tx.Model(&nodes[index]).Updates(types.Node{ + Expiry: &now, }).Error; err != nil { log.Error(). Err(err). @@ -904,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { } } - expired := make([]*tailcfg.PeerChange, len(expiredNodes)) - for idx, node := range expiredNodes { - expired[idx] = &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &started, - } + if len(expired) > 0 { + return started, types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: expired, + }, true } - // Inform the peers of a node with a lightweight update. - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - // Inform the node itself that it has expired. - for _, node := range expiredNodes { - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - } - - return started + return started, types.StateUpdate{}, false } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 140c264..5e8eb29 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) } @@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByID(0) c.Assert(err, check.IsNil) } -func (s *Suite) TestGetNodeByNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: 
uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) @@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(1), } - db.db.Save(&node) + db.DB.Save(&node) - err = db.DeleteNode(&node) + err = db.DeleteNode(&node, map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) - _, err = db.GetNode(user.Name, "testnode3") + _, err = db.getNode(user.Name, "testnode3") c.Assert(err, check.NotNil) } @@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) } node0ByID, err := db.GetNodeByID(0) @@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } - db.db.Save(&node) + db.DB.Save(&node) } aclPolicy := &policy.ACLPolicy{ @@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) { AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, } - db.db.Save(node) + db.DB.Save(node) - nodeFromDB, err := db.GetNode("test", "testnode") + nodeFromDB, err := db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, false) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB, now) + err = db.NodeSetExpiry(nodeFromDB.ID, now) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("user-1", "testnode") + _, err = db.getNode("user-1", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") @@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = db.SetTags(node, sTags) + err = 
db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) // assign duplicat tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = db.SetTags(node, eTags) + err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert( node.ForcedTags, @@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { node0ByID, err := db.GetNodeByID(0) c.Assert(err, check.IsNil) - err = db.EnableAutoApprovedRoutes(pol, node0ByID) + // TODO(kradalby): Check state update + _, err = db.EnableAutoApprovedRoutes(pol, node0ByID) c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index e743988..0fdb822 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -20,7 +20,6 @@ var ( ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) -// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( userName string, reusable bool, @@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { + return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) + }) +} - user, err := hsdb.GetUser(userName) +// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. 
+func CreatePreAuthKey(
+	tx *gorm.DB,
+	userName string,
+	reusable bool,
+	ephemeral bool,
+	expiration *time.Time,
+	aclTags []string,
+) (*types.PreAuthKey, error) {
+	user, err := GetUser(tx, userName)
 	if err != nil {
 		return nil, err
 	}
@@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 	}
 
 	now := time.Now().UTC()
-	kstr, err := hsdb.generateKey()
+	kstr, err := generateKey()
 	if err != nil {
 		return nil, err
 	}
@@ -63,29 +72,25 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 		Expiration: expiration,
 	}
 
-	err = hsdb.db.Transaction(func(db *gorm.DB) error {
-		if err := db.Save(&key).Error; err != nil {
-			return fmt.Errorf("failed to create key in the database: %w", err)
-		}
+	if err := tx.Save(&key).Error; err != nil {
+		return nil, fmt.Errorf("failed to create key in the database: %w", err)
+	}
 
-		if len(aclTags) > 0 {
-			seenTags := map[string]bool{}
+	if len(aclTags) > 0 {
+		seenTags := map[string]bool{}
 
-			for _, tag := range aclTags {
-				if !seenTags[tag] {
-					if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
-						return fmt.Errorf(
-							"failed to ceate key tag in the database: %w",
-							err,
-						)
-					}
-					seenTags[tag] = true
+		for _, tag := range aclTags {
+			if !seenTags[tag] {
+				if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
+					return nil, fmt.Errorf(
+						"failed to create key tag in the database: %w",
+						err,
+					)
				}
+				seenTags[tag] = true
 			}
 		}
-
-		return nil
-	})
+	}
 
 	if err != nil {
 		return nil, err
@@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 	return &key, nil
 }
 
-// ListPreAuthKeys returns the list of PreAuthKeys for a user.
 func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.listPreAuthKeys(userName)
+	return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
+		return ListPreAuthKeys(rx, userName)
+	})
 }
 
-func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
-	user, err := hsdb.getUser(userName)
+
+// ListPreAuthKeys returns the list of PreAuthKeys for a user.
+func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
+	user, err := GetUser(tx, userName)
 	if err != nil {
 		return nil, err
 	}
 
 	keys := []types.PreAuthKey{}
-	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
+	if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
 		return nil, err
 	}
 
@@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
 }
 
 // GetPreAuthKey returns a PreAuthKey for a given key.
-func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	pak, err := hsdb.ValidatePreAuthKey(key)
+func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
+	pak, err := ValidatePreAuthKey(tx, key)
 	if err != nil {
 		return nil, err
 	}
@@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKe
 // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
 // does not exist.
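For callers, the exported `CreatePreAuthKey` keeps its old shape; only the machinery underneath changed, from a mutex plus a nested `Transaction` to the generic `Write` helper. A small usage sketch; the user name, the tag and `createCIKey` are illustrative values, not from the patch:

```go
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/juanfont/headscale/hscontrol/db"
)

// createCIKey creates a reusable, non-ephemeral key that expires in a day.
func createCIKey(hsdb *db.HSDatabase) {
	expiry := time.Now().Add(24 * time.Hour)

	pak, err := hsdb.CreatePreAuthKey(
		"alice", // userName, must already exist
		true,    // reusable
		false,   // ephemeral
		&expiry,
		[]string{"tag:ci", "tag:ci"}, // duplicates are filtered via seenTags
	)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(pak.Key)
}
```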
-func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.destroyPreAuthKey(pak) -} - -func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { - return hsdb.db.Transaction(func(db *gorm.DB) error { +func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { + return tx.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return ExpirePreAuthKey(tx, k) + }) +} - if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { k.Used = true - if err := hsdb.db.Save(k).Error; err != nil { + if err := tx.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } return nil } +func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { + return ValidatePreAuthKey(rx, k) + }) +} + // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { pak := types.PreAuthKey{} - if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Where(&types.Node{AuthKeyID: uint(pak.ID)}). 
Find(&nodes).Error; err != nil { @@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) return &pak, nil } -func (hsdb *HSDatabase) generateKey() (string, error) { +func generateKey() (string, error) { size := 24 bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index df9c2a1..003a396 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,6 +6,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { user, err := db.CreateUser("test2") c.Assert(err, check.IsNil) - now := time.Now() + now := time.Now().Add(-5 * time.Second) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) @@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) @@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.IsNil) - db.ExpireEphemeralNodes(time.Second * 20) + db.DB.Transaction(func(tx *gorm.DB) error { + ExpireEphemeralNodes(tx, time.Second*20) + return nil + }) // The machine record should have been deleted - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.NotNil) } @@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - db.db.Save(&pak) + db.DB.Save(&pak) _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 51c7f3b..1ee144a 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -7,23 +7,15 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" - "github.com/samber/lo" "gorm.io/gorm" "tailscale.com/types/key" ) var ErrRouteIsNotAvailable = errors.New("route is not available") -func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoutes() -} - -func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { +func GetRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). 
Find(&routes).Error @@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { +func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("advertised = ? AND enabled = ?", true, true). @@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { +func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("prefix = ?", types.IPPrefix(pref)). @@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro return routes, nil } -func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { +func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ? AND advertised = true", node.ID). @@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, } func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodeRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { +func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ?", node.ID). @@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoute(id) -} - -func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { +func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). 
First(&route, id).Error @@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { return &route, nil } -func (hsdb *HSDatabase) EnableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.enableRoute(id) -} - -func (hsdb *HSDatabase) enableRoute(id uint64) error { - route, err := hsdb.getRoute(id) +func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if route.IsExitRoute() { - return hsdb.enableRoutes( + return enableRoutes( + tx, &route.Node, types.ExitRouteV4.String(), types.ExitRouteV6.String(), ) } - return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String()) + return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String()) } -func (hsdb *HSDatabase) DisableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - route, err := hsdb.getRoute(id) +func DisableRoute(tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err = hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return err + return nil, err } route.Enabled = false route.IsPrimary = false - err = hsdb.db.Save(route).Error + err = tx.Save(route).Error if err != nil { - return err + return nil, err } } else { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } for i := range routes { if routes[i].IsExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false - err = hsdb.db.Save(&routes[i]).Error + err = tx.Save(&routes[i]).Error if err != nil { - return err + return nil, err } } } } if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DisableRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create + // one and return to the caller. 
+ if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DisableRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +func (hsdb *HSDatabase) DeleteRoute( + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return DeleteRoute(tx, id, isConnected) + }) +} - route, err := hsdb.getRoute(id) +func DeleteRoute( + tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err := hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return nil + return nil, nil } - if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&route).Error; err != nil { + return nil, err } } else { - routes, err := hsdb.getNodeRoutes(&node) + routes, err := GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } routesToDelete := types.Routes{} @@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { } } - if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil { + return nil, err } } + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create + // one and return to the caller. if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DeleteRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DeleteRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) +func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error { + routes, err := GetNodeRoutes(tx, node) if err != nil { return err } for i := range routes { - if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { + if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { return err } // TODO(kradalby): This is a bit too aggressive, we could probably // figure out which routes needs to be failed over rather than all. - hsdb.failoverRouteWithNotify(&routes[i]) + failoverRouteReturnUpdate(tx, isConnected, &routes[i]) } return nil } // isUniquePrefix returns if there is another node providing the same route already. 
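Both `DisableRoute` above and `DeleteRoute` now follow the same contract: return either the update produced by the route failover or the fallback `StatePeerChanged` built in place, and leave broadcasting to the caller. A hedged sketch of a caller for `DisableRoute`, with `disableAndNotify` a hypothetical name and the `Notifier` assumed from `hscontrol/notifier`:

```go
package main

import (
	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/notifier"
	"github.com/juanfont/headscale/hscontrol/types"
	"gorm.io/gorm"
	"tailscale.com/types/key"
)

// disableAndNotify disables a route in one transaction and broadcasts
// whichever update the db layer handed back.
func disableAndNotify(
	hsdb *db.HSDatabase,
	n *notifier.Notifier,
	routeID uint64,
	isConnected map[key.MachinePublic]bool,
) error {
	update, err := db.Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
		return db.DisableRoute(tx, routeID, isConnected)
	})
	if err != nil {
		return err
	}

	if update != nil && update.Valid() {
		n.NotifyAll(*update)
	}

	return nil
}
```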
-func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { +func isUniquePrefix(tx *gorm.DB, route types.Route) bool { var count int64 - hsdb.db. - Model(&types.Route{}). + tx.Model(&types.Route{}). Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", route.Prefix, route.NodeID, @@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { return count == 0 } -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { +func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). First(&route).Error @@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro return &route, nil } +func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodePrimaryRoutes(rx, node) + }) +} + // getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true). Find(&routes).Error @@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er return routes, nil } -// SaveNodeRoutes takes a node and updates the database with -// the new routes. -// It returns a bool wheter an update should be sent as the -// saved route impacts nodes. func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.saveNodeRoutes(node) + return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) { + return SaveNodeRoutes(tx, node) + }) } -func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { +// SaveNodeRoutes takes a node and updates the database with +// the new routes. +// It returns a bool whether an update should be sent as the +// saved route impacts nodes. 
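The `sendUpdate` boolean that `SaveNodeRoutes` returns (its transaction-level body follows below) tells the caller whether the announced routes actually changed. A sketch of reacting to it, assuming the `Notifier` from `hscontrol/notifier`; `updateRoutes` and the message string are illustrative:

```go
package main

import (
	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/notifier"
	"github.com/juanfont/headscale/hscontrol/types"
)

// updateRoutes persists freshly announced routes and pings peers only if
// the set of advertised routes actually changed.
func updateRoutes(hsdb *db.HSDatabase, n *notifier.Notifier, node *types.Node) error {
	sendUpdate, err := hsdb.SaveNodeRoutes(node)
	if err != nil {
		return err
	}

	if sendUpdate {
		stateUpdate := types.StateUpdate{
			Type:        types.StatePeerChanged,
			ChangeNodes: types.Nodes{node},
			Message:     "route announcement changed", // illustrative message
		}
		if stateUpdate.Valid() {
			n.NotifyAll(stateUpdate)
		}
	}

	return nil
}
```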
+func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { sendUpdate := false currentRoutes := types.Routes{} - err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error + err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error if err != nil { return sendUpdate, err } @@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { currentRoutes[pos].Advertised = true - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { } else if route.Advertised { currentRoutes[pos].Advertised = false currentRoutes[pos].Enabled = false - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { Advertised: true, Enabled: false, } - err := hsdb.db.Create(&route).Error + err := tx.Create(&route).Error if err != nil { return sendUpdate, err } @@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { // EnsureFailoverRouteIsAvailable takes a node and checks if the node's route // currently have a functioning host that exposes the network. -func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { - nodeRoutes, err := hsdb.getNodeRoutes(node) +func EnsureFailoverRouteIsAvailable( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + node *types.Node, +) (*types.StateUpdate, error) { + nodeRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return nil + return nil, nil } + var changedNodes types.Nodes for _, nodeRoute := range nodeRoutes { - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) if err != nil { - return err + return nil, err } for _, route := range routes { if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if isConnected[route.Node.MachineKey] { continue } // if not, we need to failover the route - err := hsdb.failoverRouteWithNotify(&route) + update, err := failoverRouteReturnUpdate(tx, isConnected, &route) if err != nil { - return err + return nil, err + } + + if update != nil { + changedNodes = append(changedNodes, update.ChangeNodes...) } } } } - return nil -} - -func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) - if err != nil { - return nil - } - - var changedKeys []key.MachinePublic - - for _, route := range routes { - changed, err := hsdb.failoverRoute(&route) - if err != nil { - return err - } - - changedKeys = append(changedKeys, changed...) 
- } - - changedKeys = lo.Uniq(changedKeys) - - var nodes types.Nodes - - for _, key := range changedKeys { - node, err := hsdb.GetNodeByMachineKey(key) - if err != nil { - return err - } - - nodes = append(nodes, node) - } - - if nodes != nil { - stateUpdate := types.StateUpdate{ + if len(changedNodes) != 0 { + return &types.StateUpdate{ Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.FailoverNodeRoutesWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } + ChangeNodes: changedNodes, + Message: "called from db.EnsureFailoverRouteIsAvailable", + }, nil } - return nil + return nil, nil } -func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { - changedKeys, err := hsdb.failoverRoute(r) +func failoverRouteReturnUpdate( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) (*types.StateUpdate, error) { + changedKeys, err := failoverRoute(tx, isConnected, r) if err != nil { - return err + return nil, err } + log.Trace(). + Interface("isConnected", isConnected). + Interface("changedKeys", changedKeys). + Msg("building route failover") + if len(changedKeys) == 0 { - return nil + return nil, nil } var nodes types.Nodes - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("loading machines with new primary routes from db") - for _, key := range changedKeys { - node, err := hsdb.getNodeByMachineKey(key) + node, err := GetNodeByMachineKey(tx, key) if err != nil { - return err + return nil, err } nodes = append(nodes, node) } - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notifying peers about primary route change") - - if nodes != nil { - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.failoverRouteWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - } - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notified peers about primary route change") - - return nil + return &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: nodes, + Message: "called from db.failoverRouteReturnUpdate", + }, nil } // failoverRoute takes a route that is no longer available, @@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { // // and tries to find a new route to take over its place. // If the given route was not primary, it returns early. -func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { +func failoverRoute( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) ([]key.MachinePublic, error) { if r == nil { return nil, nil } - // This route is not a primary route, and it isnt + // This route is not a primary route, and it is not // being served to nodes. if !r.IsPrimary { return nil, nil @@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return nil, nil } - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix)) if err != nil { return nil, err } @@ -585,14 +543,18 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro continue } - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if !route.Enabled { + continue + } + + if isConnected[route.Node.MachineKey] { newPrimary = &routes[idx] break } } // If a new route was not found/available, - // return with an error. + // return without an error. 
// We do not want to update the database as // the one currently marked as primary is the // best we got. @@ -606,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Remove primary from the old route r.IsPrimary = false - err = hsdb.db.Save(&r).Error + err = tx.Save(&r).Error if err != nil { log.Error().Err(err).Msg("error disabling new primary route") @@ -619,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Set primary for the new primary newPrimary.IsPrimary = true - err = hsdb.db.Save(&newPrimary).Error + err = tx.Save(&newPrimary).Error if err != nil { log.Error().Err(err).Msg("error enabling new primary route") @@ -634,19 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil } -// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, node *types.Node, -) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return EnableAutoApprovedRoutes(tx, aclPolicy, node) + }) +} +// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. +func EnableAutoApprovedRoutes( + tx *gorm.DB, + aclPolicy *policy.ACLPolicy, + node *types.Node, +) (*types.StateUpdate, error) { if len(node.IPAddresses) == 0 { - return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs + return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs } - routes, err := hsdb.getNodeAdvertisedRoutes(node) + routes, err := GetNodeAdvertisedRoutes(tx, node) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -654,9 +623,11 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("node", node.Hostname). Msg("Could not get advertised routes for node") - return err + return nil, err } + log.Trace().Interface("routes", routes).Msg("routes for autoapproving") + approvedRoutes := types.Routes{} for _, advertisedRoute := range routes { @@ -673,9 +644,16 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Uint64("nodeId", node.ID). Msg("Failed to resolve autoApprovers for advertised route") - return err + return nil, err } + log.Trace(). + Str("node", node.Hostname). + Str("user", node.User.Name). + Strs("routeApprovers", routeApprovers). + Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()). + Msg("looking up route for autoapproving") + for _, approvedAlias := range routeApprovers { if approvedAlias == node.User.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) @@ -687,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("alias", approvedAlias). 
Msg("Failed to expand alias when processing autoApprovers policy") - return err + return nil, err } // approvedIPs should contain all of node's IPs if it matches the rule, so check for first @@ -698,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } } + update := &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{}, + Message: "created in db.EnableAutoApprovedRoutes", + } + for _, approvedRoute := range approvedRoutes { - err := hsdb.enableRoute(uint64(approvedRoute.ID)) + perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). Uint64("nodeId", node.ID). Msg("Failed to enable approved route") - return err + return nil, err } + + update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...) } - return nil + return update, nil } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d491b6a..5d6281e 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -24,7 +24,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_get_route_node") + _, err = db.getNode("test", "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -42,7 +42,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) su, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -52,10 +52,11 @@ func (s *Suite) TestGetRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = db.enableRoutes(&node, "192.168.0.0/24") + // TODO(kradalby): check state update + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) } @@ -66,7 +67,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -91,7 +92,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -106,10 +107,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = db.enableRoutes(&node, "192.168.0.0/24") + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(&node) @@ -117,14 +118,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = db.enableRoutes(&node, "150.0.10.0/25") + _, err = 
db.enableRoutes(&node, "150.0.10.0/25") c.Assert(err, check.IsNil) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node) @@ -139,7 +140,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -163,16 +164,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo1, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, route.String()) + _, err = db.enableRoutes(&node1, route.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, route2.String()) + _, err = db.enableRoutes(&node1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ @@ -186,13 +187,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo2, } - db.db.Save(&node2) + db.DB.Save(&node2) sendUpdate, err = db.SaveNodeRoutes(&node2) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node2, route2.String()) + _, err = db.enableRoutes(&node2, route2.String()) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -219,7 +220,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -246,22 +247,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { Hostinfo: &hostInfo1, LastSeen: &now, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, prefix.String()) + _, err = db.enableRoutes(&node1, prefix.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, prefix2.String()) + _, err = db.enableRoutes(&node1, prefix2.String()) c.Assert(err, check.IsNil) routes, err := db.GetNodeRoutes(&node1) c.Assert(err, check.IsNil) - err = db.DeleteRoute(uint64(routes[0].ID)) + // TODO(kradalby): check stateupdate + _, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -269,17 +271,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { c.Assert(len(enabledRoutes1), check.Equals, 1) } +var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } + func TestFailoverRoute(t *testing.T) { - ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Count/verify updates - var sink chan types.StateUpdate - - go func() { - for range sink { - } - }() - machineKeys := []key.MachinePublic{ key.NewMachine().Public(), key.NewMachine().Public(), @@ -291,6 +285,7 @@ func TestFailoverRoute(t *testing.T) { name string failingRoute types.Route routes types.Routes + isConnected map[key.MachinePublic]bool want []key.MachinePublic wantErr bool }{ @@ -371,6 +366,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ 
types.Route{ @@ -382,6 +378,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -392,8 +389,13 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: false, + Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], @@ -411,6 +413,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: false, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -422,6 +425,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -432,6 +436,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: false, + Enabled: true, }, }, want: nil, @@ -448,6 +453,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -459,6 +465,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: false, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -469,6 +476,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -479,8 +487,14 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[2], }, IsPrimary: false, + Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[1]: true, + machineKeys[2]: true, + }, want: []key.MachinePublic{ machineKeys[1], machineKeys[0], @@ -498,6 +512,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -509,6 +524,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, // Offline types.Route{ @@ -520,8 +536,13 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[3], }, IsPrimary: false, + Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[3]: false, + }, want: nil, wantErr: false, }, @@ -536,6 +557,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, routes: types.Routes{ types.Route{ @@ -547,6 +569,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[0], }, IsPrimary: true, + Enabled: true, }, // Offline types.Route{ @@ -558,6 +581,7 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[3], }, IsPrimary: false, + Enabled: true, }, types.Route{ Model: gorm.Model{ @@ -568,14 +592,61 @@ func TestFailoverRoute(t *testing.T) { MachineKey: machineKeys[1], }, IsPrimary: true, + Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + machineKeys[3]: false, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], }, wantErr: false, }, + { + name: "failover-primary-none-enabled", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + Enabled: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + Enabled: true, + }, + // not enabled + 
types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: false, + Enabled: false, + }, + }, + want: nil, + wantErr: false, + }, } for _, tt := range tests { @@ -583,13 +654,14 @@ func TestFailoverRoute(t *testing.T) { tmpDir, err := os.MkdirTemp("", "failover-db-test") assert.NoError(t, err) - notif := notifier.NewNotifier() - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notif, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, @@ -597,23 +669,15 @@ func TestFailoverRoute(t *testing.T) { ) assert.NoError(t, err) - // Pretend that all the nodes are connected to control - for idx, key := range machineKeys { - // Pretend one node is offline - if idx == 3 { - continue - } - - notif.AddNode(key, sink) - } - for _, route := range tt.routes { - if err := db.db.Save(&route).Error; err != nil { + if err := db.DB.Save(&route).Error; err != nil { t.Fatalf("failed to create route: %s", err) } } - got, err := db.failoverRoute(&tt.failingRoute) + got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) { + return failoverRoute(tx, tt.isConnected, &tt.failingRoute) + }) if (err != nil) != tt.wantErr { t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) @@ -627,3 +691,231 @@ func TestFailoverRoute(t *testing.T) { }) } } + +// func TestDisableRouteFailover(t *testing.T) { +// machineKeys := []key.MachinePublic{ +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// } + +// tests := []struct { +// name string +// nodes types.Nodes + +// routeID uint64 +// isConnected map[key.MachinePublic]bool + +// wantMachineKey key.MachinePublic +// wantErr string +// }{ +// { +// name: "single-route", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// Node: types.Node{ +// MachineKey: machineKeys[0], +// }, +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[0], +// }, +// { +// name: "failover-simple", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "no-failover-offline", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ 
+// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: false, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "failover-to-online", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: true, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// } + +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "") +// assert.NoError(t, err) + +// // bootstrap db +// datab.DB.Transaction(func(tx *gorm.DB) error { +// for _, node := range tt.nodes { +// err := tx.Save(node).Error +// if err != nil { +// return err +// } + +// _, err = SaveNodeRoutes(tx, node) +// if err != nil { +// return err +// } +// } + +// return nil +// }) + +// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { +// return DisableRoute(tx, tt.routeID, tt.isConnected) +// }) + +// // if (err.Error() != "") != tt.wantErr { +// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) + +// // return +// // } + +// if len(got.ChangeNodes) != 1 { +// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes)) +// } + +// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" { +// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff) +// } +// }) +// } +// } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c38491..e176e4b 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" ) @@ -45,9 +46,12 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 27a1406..99e9339 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,22 +15,25 @@ var ( ErrUserStillHasNodes = errors.New("user not empty: 
node(s) found") ) +func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) { + return CreateUser(tx, name) + }) +} + // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func CreateUser(tx *gorm.DB, name string) (*types.User, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } user := types.User{} - if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { + if err := tx.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } user.Name = name - if err := hsdb.db.Create(&user).Error; err != nil { + if err := tx.Create(&user).Error; err != nil { log.Error(). Str("func", "CreateUser"). Err(err). @@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { return &user, nil } +func (hsdb *HSDatabase) DestroyUser(name string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DestroyUser(tx, name) + }) +} + // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. -func (hsdb *HSDatabase) DestroyUser(name string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - user, err := hsdb.getUser(name) +func DestroyUser(tx *gorm.DB, name string) error { + user, err := GetUser(tx, name) if err != nil { return ErrUserNotFound } - nodes, err := hsdb.listNodesByUser(name) + nodes, err := ListNodesByUser(tx, name) if err != nil { return err } @@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.listPreAuthKeys(name) + keys, err := ListPreAuthKeys(tx, name) if err != nil { return err } for _, key := range keys { - err = hsdb.destroyPreAuthKey(key) + err = DestroyPreAuthKey(tx, key) if err != nil { return err } } - if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { + if result := tx.Unscoped().Delete(&user); result.Error != nil { return result.Error } return nil } +func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return RenameUser(tx, oldName, newName) + }) +} + // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameUser(tx *gorm.DB, oldName, newName string) error { var err error - oldUser, err := hsdb.getUser(oldName) + oldUser, err := GetUser(tx, oldName) if err != nil { return err } @@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.getUser(newName) + _, err = GetUser(tx, newName) if err == nil { return ErrUserExists } @@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { oldUser.Name = newName - if result := hsdb.db.Save(&oldUser); result.Error != nil { + if result := tx.Save(&oldUser); result.Error != nil { return result.Error } return nil } -// GetUser fetches a user by name. 
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getUser(name) + return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUser(rx, name) + }) } -func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { +func GetUser(tx *gorm.DB, name string) (*types.User, error) { user := types.User{} - if result := hsdb.db.First(&user, "name = ?", name); errors.Is( + if result := tx.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { return &user, nil } -// ListUsers gets all the existing users. func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listUsers() + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { + return ListUsers(rx) + }) } -func (hsdb *HSDatabase) listUsers() ([]types.User, error) { +// ListUsers gets all the existing users. +func ListUsers(tx *gorm.DB) ([]types.User, error) { users := []types.User{} - if err := hsdb.db.Find(&users).Error; err != nil { + if err := tx.Find(&users).Error; err != nil { return nil, err } @@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) { } // ListNodesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByUser(name) -} - -func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) { +func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.getUser(name) + user, err := GetUser(tx, name) if err != nil { return nil, err } nodes := types.Nodes{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -// AssignNodeToUser assigns a Node to a user. func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, node, username) + }) +} +// AssignNodeToUser assigns a Node to a user. 
+func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.getUser(username) + user, err := GetUser(tx, username) if err != nil { return err } node.User = *user - if result := hsdb.db.Save(&node); result.Error != nil { + if result := tx.Save(&node); result.Error != nil { return result.Error } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ca3b49..b36e861 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { err = db.DestroyUser("test") c.Assert(err, check.IsNil) - result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) + result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) // destroying a user also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) @@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) err = db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) @@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) c.Assert(node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, newUser.Name) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 59e4028..52a63e9 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { RegionID: d.cfg.ServerRegionID, HostName: host, DERPPort: port, + IPv4: d.cfg.IPv4, + IPv6: d.cfg.IPv6, }, }, } @@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { localDERPregion.Nodes[0].STUNPort = portSTUN log.Info().Caller().Msgf("DERP region: %+v", localDERPregion) + log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0]) return localDERPregion, nil } @@ -208,6 +211,7 @@ func DERPProbeHandler( // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // An example implementation is found here https://derp.tailscale.com/bootstrap-dns +// The coordination server is included automatically, since the local DERP uses the same DNS name in d.serverURL.
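// For a rough picture of what such an endpoint serves (an illustration,
// not this handler's exact output): a JSON object mapping DERP hostnames
// to resolved IPs, so clients can dial DERP before their own DNS works,
// e.g. {"derp.example.com": ["203.0.113.10"]}. A sketch, with ctx and w
// assumed to be the request context and response writer:
//
//	resolved := make(map[string][]net.IP)
//	for _, region := range derpMap.Regions {
//		for _, node := range region.Nodes {
//			addrs, err := net.DefaultResolver.LookupIP(ctx, "ip", node.HostName)
//			if err != nil {
//				continue // best effort: skip hosts that do not resolve
//			}
//			resolved[node.HostName] = addrs
//		}
//	}
//	_ = json.NewEncoder(w).Encode(resolved)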
func DERPBootstrapDNSHandler( derpMap *tailcfg.DERPMap, ) func(http.ResponseWriter, *http.Request) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index ffd3a57..c12ba73 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,11 +8,13 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) - if err != nil { - return nil, err - } + err := api.h.db.DB.Transaction(func(tx *gorm.DB) error { + preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) + if err != nil { + return err + } - err = api.h.db.ExpirePreAuthKey(preAuthKey) + return db.ExpirePreAuthKey(tx, preAuthKey) + }) if err != nil { return nil, err } @@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } - node, err := api.h.db.RegisterNodeFromAuthCallback( - api.h.registrationCache, - mkey, - request.GetUser(), - nil, - util.RegisterMethodCLI, - ) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + return db.RegisterNodeFromAuthCallback( + tx, + api.h.registrationCache, + mkey, + request.GetUser(), + nil, + util.RegisterMethodCLI, + api.h.cfg.IPPrefixes, + ) + }) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RegisterNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err - } - for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { - return &v1.SetTagsResponse{ - Node: nil, - }, status.Error(codes.InvalidArgument, err.Error()) + return nil, err } } - err = api.h.db.SetTags(node, request.GetTags()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.SetTags(tx, request.GetNodeId(), request.GetTags()) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return &v1.SetTagsResponse{ Node: nil, - }, status.Error(codes.Internal, err.Error()) + }, status.Error(codes.InvalidArgument, err.Error()) + } + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.SetTags", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). 
@@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode( err = api.h.db.DeleteNode( node, + api.h.nodeNotifier.ConnectedMap(), ) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) + api.h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return &v1.DeleteNodeResponse{}, nil } @@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + now := time.Now() + + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + db.NodeSetExpiry( + tx, + request.GetNodeId(), + now, + ) + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - now := time.Now() + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) + api.h.nodeNotifier.NotifyByMachineKey( + ctx, + selfUpdate, + node.MachineKey) + } - api.h.db.NodeSetExpiry( - node, - now, - ) + stateUpdate := types.StateUpdateExpire(node.ID, now) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } log.Trace(). Str("node", node.Hostname). @@ -306,17 +365,30 @@ func (api headscaleV1APIServer) RenameNode( ctx context.Context, request *v1.RenameNodeRequest, ) (*v1.RenameNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.RenameNode( + tx, + request.GetNodeId(), + request.GetNewName(), + ) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - err = api.h.db.RenameNode( - node, - request.GetNewName(), - ) - if err != nil { - return nil, err + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RenameNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). @@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + isConnected := api.h.nodeNotifier.ConnectedMap() if request.GetUser() != "" { - nodes, err := api.h.db.ListNodesByUser(request.GetUser()) + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { + return db.ListNodesByUser(rx, request.GetUser()) + }) if err != nil { return nil, err } @@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] response[index] = resp } @@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. 
- resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - &node, + node, ) resp.InvalidTags = invalidTags resp.ValidTags = validTags @@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes( ctx context.Context, request *v1.GetRoutesRequest, ) (*v1.GetRoutesResponse, error) { - routes, err := api.h.db.GetRoutes() + routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) { + return db.GetRoutes(rx) + }) if err != nil { return nil, err } @@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute( ctx context.Context, request *v1.EnableRouteRequest, ) (*v1.EnableRouteResponse, error) { - err := api.h.db.EnableRoute(request.GetRouteId()) + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnableRoute(tx, request.GetRouteId()) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown") + api.h.nodeNotifier.NotifyAll( + ctx, *update) + } + return &v1.EnableRouteResponse{}, nil } @@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute( ctx context.Context, request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { - err := api.h.db.DisableRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DisableRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") + api.h.nodeNotifier.NotifyAll(ctx, *update) + } + return &v1.DisableRouteResponse{}, nil } @@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - err := api.h.db.DeleteRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DeleteRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") + api.h.nodeNotifier.NotifyWithIgnore(ctx, *update) + } + return &v1.DeleteRouteResponse{}, nil } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index d6404ce..df0f4d9 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -272,13 +272,26 @@ func (m *Mapper) LiteMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, pol *policy.ACLPolicy, + messages ...string, ) ([]byte, error) { resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) if err != nil { return nil, err } - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) + rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + pol, + node, + nodeMapToList(m.peers), + ) + if err != nil { + return nil, err + } + + resp.PacketFilter = policy.ReduceFilterRules(node, rules) + resp.SSHPolicy = sshPolicy + + return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) 
} func (m *Mapper) KeepAliveResponse( @@ -380,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse( } if patches, ok := m.patches[uint64(change.NodeID)]; ok { - patches := append(patches, p) - - m.patches[uint64(change.NodeID)] = patches + m.patches[uint64(change.NodeID)] = append(patches, p) } else { m.patches[uint64(change.NodeID)] = []patch{p} } @@ -458,6 +469,8 @@ func (m *Mapper) marshalMapResponse( switch { case resp.Peers != nil && len(resp.Peers) > 0: responseType = "full" + case isSelfUpdate(messages...): + responseType = "self" case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil: responseType = "lite" case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: @@ -656,3 +669,13 @@ func appendPeerChanges( return nil } + +func isSelfUpdate(messages ...string) bool { + for _, message := range messages { + if strings.Contains(message, types.SelfUpdateIdentifier) { + return true + } + } + + return false +} diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index e213a95..c10da4d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -72,7 +72,7 @@ func tailNode( } var derp string - if node.Hostinfo.NetInfo != nil { + if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil { derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) } else { derp = "127.3.3.40:0" // Zero means disconnected or unknown. diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 77e8b19..2384a40 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -1,6 +1,7 @@ package notifier import ( + "context" "fmt" "strings" "sync" @@ -12,26 +13,30 @@ import ( ) type Notifier struct { - l sync.RWMutex - nodes map[string]chan<- types.StateUpdate + l sync.RWMutex + nodes map[string]chan<- types.StateUpdate + connected map[key.MachinePublic]bool } func NewNotifier() *Notifier { - return &Notifier{} + return &Notifier{ + nodes: make(map[string]chan<- types.StateUpdate), + connected: make(map[key.MachinePublic]bool), + } } func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to add node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { - n.nodes = make(map[string]chan<- types.StateUpdate) - } - n.nodes[machineKey.String()] = c + n.connected[machineKey] = true log.Trace(). Str("machine_key", machineKey.ShortString()). @@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to remove node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { + if len(n.nodes) == 0 { return } delete(n.nodes, machineKey.String()) + n.connected[machineKey] = false log.Trace(). Str("machine_key", machineKey.ShortString()). 
@@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { n.l.RLock() defer n.l.RUnlock() - if _, ok := n.nodes[machineKey.String()]; ok { - return true - } - - return false + return n.connected[machineKey] } -func (n *Notifier) NotifyAll(update types.StateUpdate) { - n.NotifyWithIgnore(update) +// TODO(kradalby): This returns a pointer and can be dangerous. +func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool { + return n.connected } -func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { +func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { + n.NotifyWithIgnore(ctx, update) +} + +func (n *Notifier) NotifyWithIgnore( + ctx context.Context, + update types.StateUpdate, + ignore ...string, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() @@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) continue } - log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } -func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) { +func (n *Notifier) NotifyByMachineKey( + ctx context.Context, + update types.StateUpdate, + mKey key.MachinePublic, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() if c, ok := n.nodes[mKey.String()]; ok { - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index c678693..53a0553 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -21,6 +21,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" + "gorm.io/gorm" "tailscale.com/types/key" ) @@ -569,7 +570,7 @@ func (h *Headscale) validateNodeForOIDCCallback( Str("node", node.Hostname). 
Msg("node already registered, reauthenticating") - err := h.db.NodeSetExpiry(node, expiry) + err := h.db.NodeSetExpiry(node.ID, expiry) if err != nil { util.LogErr(err, "Failed to refresh node") http.Error( @@ -623,6 +624,12 @@ func (h *Headscale) validateNodeForOIDCCallback( util.LogErr(err, "Failed to write response") } + stateUpdate := types.StateUpdateExpire(node.ID, expiry) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return nil, true, nil } @@ -712,14 +719,22 @@ func (h *Headscale) registerNodeForOIDCCallback( machineKey *key.MachinePublic, expiry time.Time, ) error { - if _, err := h.db.RegisterNodeFromAuthCallback( - // TODO(kradalby): find a better way to use the cache across modules - h.registrationCache, - *machineKey, - user.Name, - &expiry, - util.RegisterMethodOIDC, - ); err != nil { + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if _, err := db.RegisterNodeFromAuthCallback( + // TODO(kradalby): find a better way to use the cache across modules + tx, + h.registrationCache, + *machineKey, + user.Name, + &expiry, + util.RegisterMethodOIDC, + h.cfg.IPPrefixes, + ); err != nil { + return err + } + + return nil + }); err != nil { util.LogErr(err, "could not register node") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 4798d81..2ccc56b 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F if node.IPAddresses.InIPSet(expanded) { dests = append(dests, dest) } + + // If the node exposes routes, ensure they are note removed + // when the filters are reduced. + if node.Hostinfo != nil { + // TODO(kradalby): Evaluate if we should only keep + // the routes if the route is enabled. This will + // require database access in this part of the code. + if len(node.Hostinfo.RoutableIPs) > 0 { + for _, routableIP := range node.Hostinfo.RoutableIPs { + if expanded.ContainsPrefix(routableIP) { + dests = append(dests, dest) + } + } + } + } } if len(dests) > 0 { @@ -890,32 +905,39 @@ func (pol *ACLPolicy) TagsOfNode( validTags := make([]string, 0) invalidTags := make([]string, 0) + // TODO(kradalby): Why is this sometimes nil? coming from tailNode? 
+ if node == nil { + return validTags, invalidTags + } + validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) - for _, tag := range node.Hostinfo.RequestTags { - owners, err := expandOwnersFromTag(pol, tag) - if errors.Is(err, ErrInvalidTag) { - invalidTagMap[tag] = true + if node.Hostinfo != nil { + for _, tag := range node.Hostinfo.RequestTags { + owners, err := expandOwnersFromTag(pol, tag) + if errors.Is(err, ErrInvalidTag) { + invalidTagMap[tag] = true - continue - } - var found bool - for _, owner := range owners { - if node.User.Name == owner { - found = true + continue + } + var found bool + for _, owner := range owners { + if node.User.Name == owner { + found = true + } + } + if found { + validTagMap[tag] = true + } else { + invalidTagMap[tag] = true } } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true + for tag := range invalidTagMap { + invalidTags = append(invalidTags, tag) + } + for tag := range validTagMap { + validTags = append(validTags, tag) } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) } return validTags, invalidTags diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index c048778..4a74bda 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) { }, want: []tailcfg.FilterRule{}, }, + { + name: "1604-subnet-routers-are-preserved", + pol: ACLPolicy{ + Groups: Groups{ + "group:admins": {"user1"}, + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"group:admins:*"}, + }, + { + Action: "accept", + Sources: []string{"group:admins"}, + Destinations: []string{"10.33.0.0/16:*"}, + }, + }, + }, + node: &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user1"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.2"), + netip.MustParseAddr("fd7a:115c:a1e0::2"), + }, + User: types.User{Name: "user1"}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.1/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::1/128", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.33.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + }, + }, + }, } for _, tt := range tests { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 568f209..03f52ed 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -4,12 +4,15 @@ import ( "context" "fmt" "net/http" + "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" xslices "golang.org/x/exp/slices" + "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -125,6 +128,18 @@ func (h *Headscale) handlePoll( return } + + if h.ACLPolicy != nil { + // update routes with peer information + update, err := 
h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + if err != nil { + logErr(err, "Error running auto approved routes") + } + + if update != nil { + sendUpdate = true + } + } } // Services is mostly useful for discovery and not critical, @@ -138,41 +153,68 @@ } if sendUpdate { - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) return } + // Send an update to all peers to propagate the new routes + // available. stateUpdate := types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, Message: "called from handlePoll -> update -> new hostinfo", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } + // Send an update to the node itself to ensure it + // has an updated packetfilter allowing the new route + // if it is defined in the ACL. + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname) + h.nodeNotifier.NotifyByMachineKey( + ctx, + selfUpdate, + node.MachineKey) + } + return } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) return } + // TODO(kradalby): Figure out why patch changes do + // not show up in output from `tailscale debug netmap`.
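// (Background for the TODO above: a StatePeerChangedPatch carries
// tailcfg.PeerChange values that mutate single fields of an already-known
// peer instead of replacing it. A simplified sketch of how a receiver
// applies one; the real logic lives in the Tailscale client:)
//
//	func applyPatch(peer *tailcfg.Node, change *tailcfg.PeerChange) {
//		if change.KeyExpiry != nil {
//			peer.KeyExpiry = *change.KeyExpiry
//		}
//		if change.Online != nil {
//			peer.Online = change.Online
//		}
//	}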
+ // stateUpdate := types.StateUpdate{ + // Type: types.StatePeerChangedPatch, + // ChangePatches: []*tailcfg.PeerChange{&change}, + // } stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{&change}, + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } @@ -228,7 +270,7 @@ func (h *Headscale) handlePoll( } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -265,7 +307,10 @@ func (h *Headscale) handlePoll( // update ACLRules with peer informations (to update server tags if necessary) if h.ACLPolicy != nil { // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + // This state update is ignored as it will be sent + // as part of the whole node + // TODO(kradalby): figure out if that is actually correct + _, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) if err != nil { logErr(err, "Error running auto approved routes") } @@ -301,11 +346,17 @@ func (h *Headscale) handlePoll( Message: "called from handlePoll -> new node added", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } + if len(node.Routes) > 0 { + go h.pollFailoverRoutes(logErr, "new node", node) + } + // Set up the client stream h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -323,15 +374,9 @@ func (h *Headscale) handlePoll( keepAliveTicker := time.NewTicker(keepAliveInterval) - ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname) - - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname)) defer cancel() - if len(node.Routes) > 0 { - go h.db.EnsureFailoverRouteIsAvailable(node) - } - for { logInfo("Waiting for update on stream channel") select { @@ -370,6 +415,17 @@ func (h *Headscale) handlePoll( var data []byte var err error + // Ensure the node object is updated, for example, there + // might have been a hostinfo update in a sidechannel + // which contains data needed to generate a map response. + node, err = h.db.GetNodeByMachineKey(node.MachineKey) + if err != nil { + logErr(err, "Could not get machine from db") + + return + } + + startMapResp := time.Now() switch update.Type { case types.StateFullUpdate: logInfo("Sending Full MapResponse") @@ -378,6 +434,7 @@ func (h *Headscale) handlePoll( case types.StatePeerChanged: logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) + isConnectedMap := h.nodeNotifier.ConnectedMap() for _, node := range update.ChangeNodes { // If a node is not reported to be online, it might be // because the value is outdated, check with the notifier. @@ -385,7 +442,7 @@ func (h *Headscale) handlePoll( // this might be because it has announced itself, but not // reached the stage to actually create the notifier channel. 
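// (The stored flag comes from the database and can lag reality, while the
// connected map reflects update channels that are open right now. A
// hypothetical helper expressing that preference, for illustration only:)
//
//	func effectiveOnline(node *types.Node, connected map[key.MachinePublic]bool) bool {
//		if connected[node.MachineKey] {
//			return true // an open update channel means online now
//		}
//
//		return node.IsOnline != nil && *node.IsOnline
//	}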
if node.IsOnline != nil && !*node.IsOnline { - isOnline := h.nodeNotifier.IsConnected(node.MachineKey) + isOnline := isConnectedMap[node.MachineKey] node.IsOnline = &isOnline } } @@ -401,7 +458,7 @@ func (h *Headscale) handlePoll( if len(update.ChangeNodes) == 1 { logInfo("Sending SelfUpdate MapResponse") node = update.ChangeNodes[0] - data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy) + data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier) } else { logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.") } @@ -416,8 +473,11 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response") + // Only send update if there is change if data != nil { + startWrite := time.Now() _, err = writer.Write(data) if err != nil { logErr(err, "Could not write the map response") @@ -435,6 +495,7 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node") log.Info(). Caller(). @@ -454,7 +515,7 @@ func (h *Headscale) handlePoll( go h.updateNodeOnlineStatus(false, node) // Failover the node's routes if any. - go h.db.FailoverNodeRoutesWithNotify(node) + go h.pollFailoverRoutes(logErr, "node closing connection", node) // The connection has been closed, so we can stop polling. return @@ -467,6 +528,22 @@ func (h *Headscale) handlePoll( } } +func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) { + update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node) + }) + if err != nil { + logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) + + return + } + + if update != nil && !update.Empty() && update.Valid() { + ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String()) + } +} + // updateNodeOnlineStatus records the last seen status of a node and notifies peers // about change in their online/offline status. // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. 
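// (For illustration: such a status change could also be expressed as a
// patch-type update, mirroring StateUpdateExpire in types/common.go below;
// a hypothetical helper, not part of this diff:)
//
//	func stateUpdateOnline(nodeID uint64, online bool, lastSeen time.Time) types.StateUpdate {
//		return types.StateUpdate{
//			Type: types.StatePeerChangedPatch,
//			ChangePatches: []*tailcfg.PeerChange{{
//				NodeID:   tailcfg.NodeID(nodeID),
//				Online:   &online,
//				LastSeen: &lastSeen,
//			}},
//		}
//	}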
@@ -486,10 +563,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
 		},
 	}
 	if statusUpdate.Valid() {
-		h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
+		ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
+		h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
 	}
 
-	err := h.db.UpdateLastSeen(node)
+	err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+		return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
+	})
 	if err != nil {
 		log.Error().Err(err).Msg("Cannot update node LastSeen")
diff --git a/hscontrol/poll_noise.go b/hscontrol/poll_noise.go
index 675836a..53b1d47 100644
--- a/hscontrol/poll_noise.go
+++ b/hscontrol/poll_noise.go
@@ -13,7 +13,7 @@ import (
 )
 
 const (
-	MinimumCapVersion tailcfg.CapabilityVersion = 56
+	MinimumCapVersion tailcfg.CapabilityVersion = 58
 )
 
 // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
diff --git a/hscontrol/suite_test.go b/hscontrol/suite_test.go
index 82bdc79..3f0cc42 100644
--- a/hscontrol/suite_test.go
+++ b/hscontrol/suite_test.go
@@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
 	}
 	cfg := types.Config{
 		NoisePrivateKeyPath: tmpDir + "/noise_private.key",
-		DBtype:              "sqlite3",
-		DBpath:              tmpDir + "/headscale_test.db",
+		Database: types.DatabaseConfig{
+			Type: "sqlite3",
+			Sqlite: types.SqliteConfig{
+				Path: tmpDir + "/headscale_test.db",
+			},
+		},
 		IPPrefixes: []netip.Prefix{
 			netip.MustParsePrefix("10.27.0.0/23"),
 		},
diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go
index e38d8e3..ceeceea 100644
--- a/hscontrol/types/common.go
+++ b/hscontrol/types/common.go
@@ -1,15 +1,23 @@
 package types
 
 import (
+	"context"
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"net/netip"
+	"time"
 
 	"tailscale.com/tailcfg"
 )
 
+const (
+	SelfUpdateIdentifier = "self-update"
+	DatabasePostgres     = "postgres"
+	DatabaseSqlite       = "sqlite3"
+)
+
 var ErrCannotParsePrefix = errors.New("cannot parse prefix")
 
 type IPPrefix netip.Prefix
@@ -150,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
 		}
 	case StateSelfUpdate:
 		if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
-			panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
+			panic(
+				"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
+			)
 		}
 	case StateDERPUpdated:
 		if su.DERPMap == nil {
@@ -160,3 +170,37 @@ func (su *StateUpdate) Valid() bool {
 
 	return true
 }
+
+// Empty reports whether the StateUpdate contains no changes to announce.
+func (su *StateUpdate) Empty() bool { + switch su.Type { + case StatePeerChanged: + return len(su.ChangeNodes) == 0 + case StatePeerChangedPatch: + return len(su.ChangePatches) == 0 + case StatePeerRemoved: + return len(su.Removed) == 0 + } + + return false +} + +func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(nodeID), + KeyExpiry: &expiry, + }, + }, + } +} + +func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { + ctx2, _ := context.WithTimeout( + context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin), + 3*time.Second, + ) + return ctx2 +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 01cb9fd..89bd07e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -11,7 +11,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -20,6 +19,8 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + + "github.com/juanfont/headscale/hscontrol/util" ) const ( @@ -46,16 +47,9 @@ type Config struct { Log LogConfig DisableUpdateCheck bool - DERP DERPConfig + Database DatabaseConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DERP DERPConfig TLS TLSConfig @@ -77,6 +71,31 @@ type Config struct { ACL ACLConfig } +type SqliteConfig struct { + Path string +} + +type PostgresConfig struct { + Host string + Port int + Name string + User string + Pass string + Ssl string + MaxOpenConnections int + MaxIdleConnections int + ConnMaxIdleTimeSecs int +} + +type DatabaseConfig struct { + // Type sets the database type, either "sqlite3" or "postgres" + Type string + Debug bool + + Sqlite SqliteConfig + Postgres PostgresConfig +} + type TLSConfig struct { CertPath string KeyPath string @@ -111,16 +130,19 @@ type OIDCConfig struct { } type DERPConfig struct { - ServerEnabled bool - ServerRegionID int - ServerRegionCode string - ServerRegionName string - ServerPrivateKeyPath string - STUNAddr string - URLs []url.URL - Paths []string - AutoUpdate bool - UpdateFrequency time.Duration + ServerEnabled bool + AutomaticallyAddEmbeddedDerpRegion bool + ServerRegionID int + ServerRegionCode string + ServerRegionName string + ServerPrivateKeyPath string + STUNAddr string + URLs []url.URL + Paths []string + AutoUpdate bool + UpdateFrequency time.Duration + IPv4 string + IPv6 string } type LogTailConfig struct { @@ -162,6 +184,19 @@ func LoadConfig(path string, isFile bool) error { viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() + viper.RegisterAlias("db_type", "database.type") + + // SQLite aliases + viper.RegisterAlias("db_path", "database.sqlite.path") + + // Postgres aliases + viper.RegisterAlias("db_host", "database.postgres.host") + viper.RegisterAlias("db_port", "database.postgres.port") + viper.RegisterAlias("db_name", "database.postgres.name") + viper.RegisterAlias("db_user", "database.postgres.user") + viper.RegisterAlias("db_pass", "database.postgres.pass") + viper.RegisterAlias("db_ssl", "database.postgres.ssl") + viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType) @@ -173,6 +208,7 @@ func LoadConfig(path string, 
isFile bool) error { viper.SetDefault("derp.server.enabled", false) viper.SetDefault("derp.server.stun.enabled", true) + viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true) viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock") viper.SetDefault("unix_socket_permission", "0o770") @@ -184,6 +220,10 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("cli.insecure", false) viper.SetDefault("db_ssl", false) + viper.SetDefault("database.postgres.ssl", false) + viper.SetDefault("database.postgres.max_open_conns", 10) + viper.SetDefault("database.postgres.max_idle_conns", 10) + viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.strip_email_domain", true) @@ -262,7 +302,7 @@ func LoadConfig(path string, isFile bool) error { } if errorText != "" { - //nolint + // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) } else { return nil @@ -294,8 +334,14 @@ func GetDERPConfig() DERPConfig { serverRegionCode := viper.GetString("derp.server.region_code") serverRegionName := viper.GetString("derp.server.region_name") stunAddr := viper.GetString("derp.server.stun_listen_addr") - privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path")) - + privateKeyPath := util.AbsolutePathFromConfigPath( + viper.GetString("derp.server.private_key_path"), + ) + ipv4 := viper.GetString("derp.server.ipv4") + ipv6 := viper.GetString("derp.server.ipv6") + automaticallyAddEmbeddedDerpRegion := viper.GetBool( + "derp.server.automatically_add_embedded_derp_region", + ) if serverEnabled && stunAddr == "" { log.Fatal(). Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") @@ -318,20 +364,28 @@ func GetDERPConfig() DERPConfig { paths := viper.GetStringSlice("derp.paths") + if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 { + log.Fatal(). + Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths") + } + autoUpdate := viper.GetBool("derp.auto_update_enabled") updateFrequency := viper.GetDuration("derp.update_frequency") return DERPConfig{ - ServerEnabled: serverEnabled, - ServerRegionID: serverRegionID, - ServerRegionCode: serverRegionCode, - ServerRegionName: serverRegionName, - ServerPrivateKeyPath: privateKeyPath, - STUNAddr: stunAddr, - URLs: urls, - Paths: paths, - AutoUpdate: autoUpdate, - UpdateFrequency: updateFrequency, + ServerEnabled: serverEnabled, + ServerRegionID: serverRegionID, + ServerRegionCode: serverRegionCode, + ServerRegionName: serverRegionName, + ServerPrivateKeyPath: privateKeyPath, + STUNAddr: stunAddr, + URLs: urls, + Paths: paths, + AutoUpdate: autoUpdate, + UpdateFrequency: updateFrequency, + IPv4: ipv4, + IPv6: ipv6, + AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion, } } @@ -379,6 +433,45 @@ func GetLogConfig() LogConfig { } } +func GetDatabaseConfig() DatabaseConfig { + debug := viper.GetBool("database.debug") + + type_ := viper.GetString("database.type") + + switch type_ { + case DatabaseSqlite, DatabasePostgres: + break + case "sqlite": + type_ = "sqlite3" + default: + log.Fatal(). 
+ Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) + } + + return DatabaseConfig{ + Type: type_, + Debug: debug, + Sqlite: SqliteConfig{ + Path: util.AbsolutePathFromConfigPath( + viper.GetString("database.sqlite.path"), + ), + }, + Postgres: PostgresConfig{ + Host: viper.GetString("database.postgres.host"), + Port: viper.GetInt("database.postgres.port"), + Name: viper.GetString("database.postgres.name"), + User: viper.GetString("database.postgres.user"), + Pass: viper.GetString("database.postgres.pass"), + Ssl: viper.GetString("database.postgres.ssl"), + MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"), + MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"), + ConnMaxIdleTimeSecs: viper.GetInt( + "database.postgres.conn_max_idle_time_secs", + ), + }, + } +} + func GetDNSConfig() (*tailcfg.DNSConfig, string) { if viper.IsSet("dns_config") { dnsConfig := &tailcfg.DNSConfig{} @@ -580,7 +673,7 @@ func GetHeadscaleConfig() (*Config, error) { if err != nil { return nil, err } - oidcClientSecret = string(secretBytes) + oidcClientSecret = strings.TrimSpace(string(secretBytes)) } return &Config{ @@ -607,14 +700,7 @@ func GetHeadscaleConfig() (*Config, error) { "node_update_check_interval", ), - DBtype: viper.GetString("db_type"), - DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - DBssl: viper.GetString("db_ssl"), + Database: GetDatabaseConfig(), TLS: GetTLSConfig(), diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 9b2ba76..4434264 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri // inform peers about smaller changes to the node. // When a field is added to this function, remember to also add it to: // - node.ApplyPeerChange -// - logTracePeerChange in poll.go +// - logTracePeerChange in poll.go. 
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { ret := tailcfg.PeerChange{ NodeID: tailcfg.NodeID(node.ID), diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 7e6c984..712a839 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -6,6 +6,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) { }) } } + +func TestApplyPeerChange(t *testing.T) { + tests := []struct { + name string + nodeBefore Node + change *tailcfg.PeerChange + want Node + }{ + { + name: "hostinfo-and-netinfo-not-exists", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + DERPRegion: 1, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 1, + }, + }, + }, + }, + { + name: "hostinfo-netinfo-not-exists", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 3, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 3, + }, + }, + }, + }, + { + name: "hostinfo-netinfo-exists-derp-set", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 999, + }, + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 2, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 2, + }, + }, + }, + }, + { + name: "endpoints-not-set", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + { + name: "endpoints-set", + nodeBefore: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("6.6.6.6:66"), + }, + }, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.nodeBefore.ApplyPeerChange(tt.change) + + if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 7f6b40e..0b8324f 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,7 +2,6 @@ package types import ( "strconv" - "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -22,12 +21,13 @@ type User struct { func (n *User) TailscaleUser() *tailcfg.User { user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.UserID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", Logins: []tailcfg.LoginID{}, - Created: time.Time{}, + Created: n.CreatedAt, } return &user @@ -35,9 +35,10 @@ func (n *User) TailscaleUser() *tailcfg.User { func (n *User) TailscaleLogin() *tailcfg.Login { login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + 
DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", } diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go index 6d46542..0a23acb 100644 --- a/hscontrol/util/test.go +++ b/hscontrol/util/test.go @@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool { return x.Compare(y) == 0 }) +var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool { + return x == y +}) + var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { return x.String() == y.String() }) @@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { }) var Comparers []cmp.Option = []cmp.Option{ - IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer, + IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index a861fd6..e51ccd1 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -265,6 +265,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -325,6 +327,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 90ce571..aa589fa 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) diff --git a/integration/cli_test.go b/integration/cli_test.go index d2d741e..e6190fb 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAll[4].GetGivenName(), "node-5") for idx := 0; idx < 3; idx++ { - _, err := headscale.Execute( + res, err := headscale.Execute( []string{ "headscale", "nodes", @@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) { }, ) assert.Nil(t, err) + + assert.Contains(t, res, "Node renamed") } var listAllAfterRename []v1.Node diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 3a40749..15ab7ad 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) { defer scenario.Shutdown() spec := map[string]int{ - "user1": len(MustTestVersions), + "user1": 10, + // "user1": len(MustTestVersions), } - headscaleConfig := map[string]string{} - headscaleConfig["HEADSCALE_DERP_URLS"] = "" - headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] 
= "Headscale Embedded DERP" - headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" - headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" - // Envknob for enabling DERP debug logs - headscaleConfig["DERP_DEBUG_LOGS"] = "true" - headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true" + headscaleConfig := map[string]string{ + "HEADSCALE_DERP_URLS": "", + "HEADSCALE_DERP_SERVER_ENABLED": "true", + "HEADSCALE_DERP_SERVER_REGION_ID": "999", + "HEADSCALE_DERP_SERVER_REGION_CODE": "headscale", + "HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP", + "HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478", + "HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key", + + // Envknob for enabling DERP debug logs + "DERP_DEBUG_LOGS": "true", + "DERP_PROBER_DEBUG_LOGS": "true", + } err = scenario.CreateHeadscaleEnv( spec, @@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allHostnames, err := scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) diff --git a/integration/general_test.go b/integration/general_test.go index c092844..9aae26f 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -26,12 +26,34 @@ func TestPingAllByIP(t *testing.T) { assertNoErr(t, err) defer scenario.Shutdown() + // TODO(kradalby): it does not look like the user thing works, only second + // get created? maybe only when many? spec := map[string]int{ "user1": len(MustTestVersions), "user2": len(MustTestVersions), } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip")) + headscaleConfig := map[string]string{ + "HEADSCALE_DERP_URLS": "", + "HEADSCALE_DERP_SERVER_ENABLED": "true", + "HEADSCALE_DERP_SERVER_REGION_ID": "999", + "HEADSCALE_DERP_SERVER_REGION_CODE": "headscale", + "HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP", + "HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478", + "HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key", + + // Envknob for enabling DERP debug logs + "DERP_DEBUG_LOGS": "true", + "DERP_PROBER_DEBUG_LOGS": "true", + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithTestName("pingallbyip"), + hsic.WithConfigEnv(headscaleConfig), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + ) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -43,6 +65,46 @@ func TestPingAllByIP(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) +} + +func TestPingAllByIPPublicDERP(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario() + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithTestName("pingallbyippubderp"), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + assertClientsState(t, 
allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -73,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + clientIPs := make(map[TailscaleClient][]netip.Addr) for _, client := range allClients { ips, err := client.IPs() @@ -112,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allClients, err = scenario.ListTailscaleClients() assertNoErrListClients(t, err) @@ -263,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allHostnames, err := scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) @@ -320,9 +388,13 @@ func TestTaildrop(t *testing.T) { if err != nil { t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) } - } - curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} + curlCommand := []string{ + "curl", + "--unix-socket", + "/var/run/tailscale/tailscaled.sock", + "http://local-tailscaled.sock/localapi/v0/file-targets", + } err = retry(10, 1*time.Second, func() error { result, _, err := client.Execute(curlCommand) if err != nil { @@ -339,13 +411,23 @@ func TestTaildrop(t *testing.T) { for _, ft := range fts { ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) } - return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) + return fmt.Errorf( + "client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", + client.Hostname(), + len(fts), + len(allClients)-1, + ftStr, + ) } return err }) if err != nil { - t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) + t.Errorf( + "failed to query localapi for filetarget on %s, err: %s", + client.Hostname(), + err, + ) } } @@ -457,6 +539,8 @@ func TestResolveMagicDNS(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + // Poor mans cache _, err = scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) @@ -525,6 +609,8 @@ func TestExpireNode(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -546,7 +632,7 @@ func TestExpireNode(t *testing.T) { // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ - "headscale", "nodes", "expire", "--identifier", "0", "--output", "json", + "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) assertNoErr(t, err) @@ -577,16 +663,38 @@ func TestExpireNode(t *testing.T) { assertNotNil(t, peerStatus.Expired) assert.NotNil(t, peerStatus.KeyExpiry) - t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + t.Logf( + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) if peerStatus.KeyExpiry != nil { - assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", 
peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + assert.Truef( + t, + peerStatus.KeyExpiry.Before(now), + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) } - assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + assert.Truef( + t, + peerStatus.Expired, + "node %q should be expired, expired is %v", + peerStatus.HostName, + peerStatus.Expired, + ) _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) if !strings.Contains(stderr, "node key has expired") { - t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname()) + t.Errorf( + "expected to be unable to ping expired host %q from %q", + node.GetName(), + client.Hostname(), + ) } } else { t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) @@ -598,7 +706,7 @@ func TestExpireNode(t *testing.T) { // NeedsLogin means that the node has understood that it is no longer // valid. - assert.Equal(t, "NeedsLogin", status.BackendState) + assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName) } } } @@ -627,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + assertClientsState(t, allClients) + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -691,7 +801,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { assert.Truef( t, lastSeen.After(lastSeenThreshold), - "lastSeen (%v) was not %s after the threshold (%v)", + "node (%s) lastSeen (%v) was not %s after the threshold (%v)", + node.GetName(), lastSeen, keepAliveInterval, lastSeenThreshold, diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 00c1770..819b108 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string { return map[string]string{ "HEADSCALE_LOG_LEVEL": "trace", "HEADSCALE_ACL_POLICY_PATH": "", - "HEADSCALE_DB_TYPE": "sqlite3", - "HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_DATABASE_TYPE": "sqlite", + "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s", "HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10", diff --git a/integration/route_test.go b/integration/route_test.go index 489165a..75296fd 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -9,10 +9,15 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "tailscale.com/types/ipproto" + "tailscale.com/wgengine/filter" ) // This test is both testing the routes command and the propagation of @@ -83,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, routes, 3) for _, route := range routes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, 
route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } // Verify that no routes has been sent to the client, @@ -130,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, enablingRoutes, 3) for _, route := range enablingRoutes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } time.Sleep(5 * time.Second) @@ -186,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) { }) assertNoErr(t, err) + time.Sleep(5 * time.Second) + var disablingRoutes []*v1.Route err = executeAndUnmarshal( headscale, @@ -204,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) { assert.Equal(t, true, route.GetAdvertised()) if route.GetId() == routeToBeDisabled.GetId() { - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } else { - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } } - time.Sleep(5 * time.Second) - // Verify that the clients can see the new routes for _, client := range allClients { status, err := client.Status() @@ -289,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // advertise HA route on node 1 and 2 // ID 1 will be primary // ID 2 will be secondary - for _, client := range allClients { + for _, client := range allClients[:2] { status, err := client.Status() assertNoErr(t, err) @@ -301,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) { } _, _, err = client.Execute(command) assertNoErrf(t, "failed to advertise route: %s", err) + } else { + t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID) } } @@ -323,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routes, 2) + t.Logf("initial routes %#v", routes) + for _, route := range routes { assert.Equal(t, true, route.GetAdvertised()) assert.Equal(t, false, route.GetEnabled()) @@ -639,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routesAfterDisabling1, 2) + t.Logf("routes after disabling1 %#v", routesAfterDisabling1) + // Node 1 is not primary assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) @@ -778,3 +789,413 @@ func TestHASubnetRouterFailover(t *testing.T) { ) } } + +func TestEnableDisableAutoApprovedRoute(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + expectedRoutes := "172.0.0.0/24" + + user := "enable-disable-routing" + + scenario, err := NewScenario() + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.Shutdown() + + spec := map[string]int{ + user: 1, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + &policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:approve": {user}, + }, + AutoApprovers: policy.AutoApprovers{ + Routes: map[string][]string{ + expectedRoutes: {"tag:approve"}, + }, + }, + }, + )) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := 
scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + subRouter1 := allClients[0] + + // Initially advertise route + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes, + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + + time.Sleep(10 * time.Second) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + assertNoErr(t, err) + assert.Len(t, routes, 1) + + // All routes should be auto approved and enabled + assert.Equal(t, true, routes[0].GetAdvertised()) + assert.Equal(t, true, routes[0].GetEnabled()) + assert.Equal(t, true, routes[0].GetIsPrimary()) + + // Stop advertising route + command = []string{ + "tailscale", + "set", + "--advertise-routes=", + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to remove advertised route: %s", err) + + time.Sleep(10 * time.Second) + + var notAdvertisedRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + ¬AdvertisedRoutes, + ) + assertNoErr(t, err) + assert.Len(t, notAdvertisedRoutes, 1) + + // Route is no longer advertised + assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised()) + assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled()) + assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary()) + + // Advertise route again + command = []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes, + } + _, _, err = subRouter1.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + + time.Sleep(10 * time.Second) + + var reAdvertisedRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &reAdvertisedRoutes, + ) + assertNoErr(t, err) + assert.Len(t, reAdvertisedRoutes, 1) + + // All routes should be auto approved and enabled + assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised()) + assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled()) + assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary()) +} + +// TestSubnetRouteACL verifies that Subnet routes are distributed +// as expected when ACLs are activated. 
+// It reproduces the issue from
+// https://github.com/juanfont/headscale/issues/1604
+func TestSubnetRouteACL(t *testing.T) {
+	IntegrationSkip(t)
+	t.Parallel()
+
+	user := "subnet-route-acl"
+
+	scenario, err := NewScenario()
+	assertNoErrf(t, "failed to create scenario: %s", err)
+	defer scenario.Shutdown()
+
+	spec := map[string]int{
+		user: 2,
+	}
+
+	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
+		&policy.ACLPolicy{
+			Groups: policy.Groups{
+				"group:admins": {user},
+			},
+			ACLs: []policy.ACL{
+				{
+					Action:       "accept",
+					Sources:      []string{"group:admins"},
+					Destinations: []string{"group:admins:*"},
+				},
+				{
+					Action:       "accept",
+					Sources:      []string{"group:admins"},
+					Destinations: []string{"10.33.0.0/16:*"},
+				},
+				// {
+				// 	Action:       "accept",
+				// 	Sources:      []string{"group:admins"},
+				// 	Destinations: []string{"0.0.0.0/0:*"},
+				// },
+			},
+		},
+	))
+	assertNoErrHeadscaleEnv(t, err)
+
+	allClients, err := scenario.ListTailscaleClients()
+	assertNoErrListClients(t, err)
+
+	err = scenario.WaitForTailscaleSync()
+	assertNoErrSync(t, err)
+
+	headscale, err := scenario.Headscale()
+	assertNoErrGetHeadscale(t, err)
+
+	expectedRoutes := map[string]string{
+		"1": "10.33.0.0/16",
+	}
+
+	// Sort nodes by ID
+	sort.SliceStable(allClients, func(i, j int) bool {
+		statusI, err := allClients[i].Status()
+		if err != nil {
+			return false
+		}
+
+		statusJ, err := allClients[j].Status()
+		if err != nil {
+			return false
+		}
+
+		return statusI.Self.ID < statusJ.Self.ID
+	})
+
+	subRouter1 := allClients[0]
+
+	client := allClients[1]
+
+	// advertise the route on node 1; expectedRoutes only
+	// contains an entry for its ID.
+	for _, client := range allClients {
+		status, err := client.Status()
+		assertNoErr(t, err)
+
+		if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
+			command := []string{
+				"tailscale",
+				"set",
+				"--advertise-routes=" + route,
+			}
+			_, _, err = client.Execute(command)
+			assertNoErrf(t, "failed to advertise route: %s", err)
+		}
+	}
+
+	err = scenario.WaitForTailscaleSync()
+	assertNoErrSync(t, err)
+
+	var routes []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routes,
+	)
+
+	assertNoErr(t, err)
+	assert.Len(t, routes, 1)
+
+	for _, route := range routes {
+		assert.Equal(t, true, route.GetAdvertised())
+		assert.Equal(t, false, route.GetEnabled())
+		assert.Equal(t, false, route.GetIsPrimary())
+	}
+
+	// Verify that no routes have been sent to the client;
+	// they are not yet enabled.
+ for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + } + } + + // Enable all routes + for _, route := range routes { + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "enable", + "--route", + strconv.Itoa(int(route.GetId())), + }) + assertNoErr(t, err) + } + + time.Sleep(5 * time.Second) + + var enablingRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &enablingRoutes, + ) + assertNoErr(t, err) + assert.Len(t, enablingRoutes, 1) + + // Node 1 has active route + assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) + assert.Equal(t, true, enablingRoutes[0].GetEnabled()) + assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) + + // Verify that the client has routes from the primary machine + srs1, _ := subRouter1.Status() + + clientStatus, err := client.Status() + assertNoErr(t, err) + + srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] + + assertNotNil(t, srs1PeerStatus.PrimaryRoutes) + + t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice()) + assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1) + assert.Contains( + t, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + ) + + clientNm, err := client.Netmap() + assertNoErr(t, err) + + wantClientFilter := []filter.Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.2/32"), + Ports: filter.PortRange{0, 0xffff}, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" { + t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) + } + + subnetNm, err := subRouter1.Netmap() + assertNoErr(t, err) + + wantSubnetFilter := []filter.Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.1/32"), + Ports: filter.PortRange{0, 0xffff}, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + { + IPProto: []ipproto.Proto{ + ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + }, + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("10.33.0.0/16"), + Ports: filter.PortRange{0, 0xffff}, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + if diff := cmp.Diff(wantSubnetFilter, 
subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
+		t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
+	}
+}
diff --git a/integration/scenario.go b/integration/scenario.go
index c11af72..16ec6f4 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -56,8 +56,8 @@ var (
 		"1.44": true, // CapVer: 63
 		"1.42": true, // CapVer: 61
 		"1.40": true, // CapVer: 61
-		"1.38": true, // CapVer: 58
-		"1.36": true, // Oldest supported version, CapVer: 56
+		"1.38": true, // Oldest supported version, CapVer: 58
+		"1.36": false, // CapVer: 56
 		"1.34": false, // CapVer: 51
 		"1.32": false, // CapVer: 46
 		"1.30": false,
diff --git a/integration/tailscale.go b/integration/tailscale.go
index e7bf71b..9d6796b 100644
--- a/integration/tailscale.go
+++ b/integration/tailscale.go
@@ -7,6 +7,8 @@ import (
 	"github.com/juanfont/headscale/integration/dockertestutil"
 	"github.com/juanfont/headscale/integration/tsic"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/net/netcheck"
+	"tailscale.com/types/netmap"
 )
 
 // nolint
@@ -26,6 +28,8 @@ type TailscaleClient interface {
 	IPs() ([]netip.Addr, error)
 	FQDN() (string, error)
 	Status() (*ipnstate.Status, error)
+	Netmap() (*netmap.NetworkMap, error)
+	Netcheck() (*netcheck.Report, error)
 	WaitForNeedsLogin() error
 	WaitForRunning() error
 	WaitForPeers(expected int) error
diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go
index 7404f6e..854d5a7 100644
--- a/integration/tsic/tsic.go
+++ b/integration/tsic/tsic.go
@@ -17,6 +17,8 @@ import (
 	"github.com/ory/dockertest/v3"
 	"github.com/ory/dockertest/v3/docker"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/net/netcheck"
+	"tailscale.com/types/netmap"
 )
 
 const (
@@ -519,6 +521,53 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
 	return &status, err
 }
 
+// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
+// Only works with Tailscale 1.56.1 and newer.
+func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
+	command := []string{
+		"tailscale",
+		"debug",
+		"netmap",
+	}
+
+	result, stderr, err := t.Execute(command)
+	if err != nil {
+		fmt.Printf("stderr: %s\n", stderr)
+		return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
+	}
+
+	var nm netmap.NetworkMap
+	err = json.Unmarshal([]byte(result), &nm)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
+	}
+
+	return &nm, err
+}
+
+// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
+func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
+	command := []string{
+		"tailscale",
+		"netcheck",
+		"--format=json",
+	}
+
+	result, stderr, err := t.Execute(command)
+	if err != nil {
+		fmt.Printf("stderr: %s\n", stderr)
+		return nil, fmt.Errorf("failed to execute tailscale netcheck command: %w", err)
+	}
+
+	var nm netcheck.Report
+	err = json.Unmarshal([]byte(result), &nm)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
+	}
+
+	return &nm, err
+}
+
 // FQDN returns the FQDN as a string of the Tailscale instance.
 func (t *TailscaleInContainer) FQDN() (string, error) {
 	if t.fqdn != "" {
@@ -623,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
 			len(peers),
 		)
 	} else {
+		// Verify that the peers of a given node are online,
+		// and that each has a hostname and a DERP relay.
for _, peerKey := range peers { peer := status.Peer[peerKey] if !peer.Online { return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName) } + + if peer.HostName == "" { + return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName) + } + + if peer.Relay == "" { + return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName) + } } } diff --git a/integration/utils.go b/integration/utils.go index e17e18a..ae4441b 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -7,6 +7,8 @@ import ( "time" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "tailscale.com/util/cmpver" ) const ( @@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts for _, addr := range addrs { err := client.Ping(addr, opts...) if err != nil { - t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) } else { success++ } @@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) return success } +// assertClientsState validates the status and netmap of a list of +// clients for the general case of all to all connectivity. +func assertClientsState(t *testing.T, clients []TailscaleClient) { + t.Helper() + + for _, client := range clients { + assertValidStatus(t, client) + assertValidNetmap(t, client) + assertValidNetcheck(t, client) + } +} + +// assertValidNetmap asserts that the netmap of a client has all +// the minimum required fields set to a known working config for +// the general case. Fields are checked on self, then all peers. +// This test is not suitable for ACL/partial connection tests. +// This test can only be run on clients from 1.56.1. It will +// automatically pass all clients below that and is safe to call +// for all versions. 
+func assertValidNetmap(t *testing.T, client TailscaleClient) {
+	t.Helper()
+
+	if cmpver.Compare(client.Version(), "1.56.1") < 0 &&
+		!strings.Contains(client.Hostname(), "unstable") &&
+		!strings.Contains(client.Hostname(), "head") {
+		return
+	}
+
+	netmap, err := client.Netmap()
+	if err != nil {
+		t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
+	}
+
+	assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
+	if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
+		assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
+	}
+
+	assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
+	assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
+
+	assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
+
+	assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
+	assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
+	assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
+
+	for _, peer := range netmap.Peers {
+		assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP())
+
+		assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
+		if hi := peer.Hostinfo(); hi.Valid() {
+			assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
+
+			// NetInfo is not always set, so PreferredDERP is only checked when it is valid.
+			assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
+			if ni := hi.NetInfo(); ni.Valid() {
+				assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
+			}
+		}
+
+		assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
+		assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
+		assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
+
+		assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
+
+		assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
+		assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
+		assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
+	}
+}
+
+// assertValidStatus asserts that the status of a client has all
+// the minimum required fields set to a known working config for
+// the general case. Fields are checked on self, then all peers.
+// This test is not suitable for ACL/partial connection tests.
+func assertValidStatus(t *testing.T, client TailscaleClient) {
+	t.Helper()
+	status, err := client.Status()
+	if err != nil {
+		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
+	}
+
+	assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
+
+	assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
+
+	// This does not seem to appear until version 1.56
+	if status.Self.AllowedIPs != nil {
+		assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
+	}
+
+	assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
+
+	assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
+
+	assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
+
+	// This isn't really relevant for Self, as it won't be in its own socket/wireguard.
+	// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
+	// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
+
+	for _, peer := range status.Peer {
+		assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
+
+		assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
+
+		// This does not seem to appear until version 1.56
+		if peer.AllowedIPs != nil {
+			assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
+		}
+
+		// Addrs does not seem to appear in the status from peers.
+		// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
+		assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
+
+		// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
+		// there might be some interesting stuff to test here in the future.
+		// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
+	}
+}
+
+func assertValidNetcheck(t *testing.T, client TailscaleClient) {
+	t.Helper()
+	report, err := client.Netcheck()
+	if err != nil {
+		t.Fatalf("getting netcheck report for %q: %s", client.Hostname(), err)
+	}
+
+	assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
+}
+
 func isSelfClient(client TailscaleClient, addr string) bool {
 	if addr == client.Hostname() {
 		return true
@@ -152,7 +296,7 @@ func isCI() bool {
 }
 
 func dockertestMaxWait() time.Duration {
-	wait := 60 * time.Second //nolint
+	wait := 120 * time.Second //nolint
 	if isCI() {
 		wait = 300 * time.Second //nolint
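
A closing editorial note on the version gate in assertValidNetmap above: tailscale.com/util/cmpver.Compare returns a negative number when its first argument sorts before its second, zero when they are equal, and a positive number otherwise. The sketch below shows the gate's intended semantics as a standalone predicate; shouldCheckNetmap is a hypothetical helper name, and the hostname-based escape hatch for "unstable" and "head" images is assumed from the surrounding test code.

package main

import (
	"fmt"
	"strings"

	"tailscale.com/util/cmpver"
)

// shouldCheckNetmap reports whether the netmap assertions should run:
// clients at 1.56.1 or newer are checked, and "unstable"/"head" images
// are always checked since they track newer code than their tag implies.
func shouldCheckNetmap(version, hostname string) bool {
	tooOld := cmpver.Compare(version, "1.56.1") < 0
	devBuild := strings.Contains(hostname, "unstable") ||
		strings.Contains(hostname, "head")

	return !tooOld || devBuild
}

func main() {
	fmt.Println(shouldCheckNetmap("1.54.0", "ts-1540-abc")) // false: too old, release build
	fmt.Println(shouldCheckNetmap("1.58.2", "ts-1582-abc")) // true: new enough
	fmt.Println(shouldCheckNetmap("1.0.0", "ts-head-abc"))  // true: head build
}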