Compare commits

...

14 commits

Author SHA1 Message Date
Ty
cc430af3f9
Merge upstream/main into fork 2024-02-11 21:48:13 -07:00
lööps
c3257e2146
docs(windows-client): add Windows registry command (#1658)
Add Windows registry command to create the `Tailscale IPN` path before setting properties.
2024-02-09 19:16:17 +01:00
Pallab Pain
9047c09871
feat: add pqsql configs for open and idle connections (#1583)
When Postgres is used as the backing database for headscale,
it does not set a limit on maximum open and idle connections
which leads to hundreds of open connections to the Postgres
server.

This commit introduces the configuration variables to set those
values and also sets default while opening a new postgres connection.
2024-02-09 17:34:28 +01:00
Kristoffer Dalby
91bb85e7d2
Update bug_report.md (#1672) 2024-02-09 07:27:13 +01:00
Kristoffer Dalby
94b30abf56
Restructure database config (#1700) 2024-02-09 07:27:00 +01:00
Kristoffer Dalby
00e7550e76
Add assert func for verifying status, netmap and netcheck (#1723) 2024-02-09 07:26:41 +01:00
Kristoffer Dalby
83769ba715
Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the
database and replaces them with Transactions, turns out that SQL had
a way to deal with this all along.

This reduces the complexity we had with multiple locks that might stack
or recurse (database, nofitifer, mapper). All notifications and state
updates are now triggered _after_ a database change.


Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-02-08 17:28:19 +01:00
DeveloperDragon
cbf57e27a7
Login with OIDC after having been logged out (#1719) 2024-02-05 10:45:35 +01:00
derelm
4ea12f472a
Fix failover to disabled route #1706 (#1707)
* fix #1706 - failover should disregard disabled routes during failover

* fixe tests for failover; all current tests assume routes to be enabled

* add testcase for #1706 - failover to disabled route
2024-02-03 15:30:15 +01:00
danielalvsaaker
b4210e2c90
Trim client secret after reading from file (#1697)
Reading from file will include a line break, which results in a mismatching client secret
compared to reading directly from the config.
2024-01-25 09:53:34 +01:00
dyz
a369d57a17
fix node expire error due to type in gorm model Update (#1692)
Fixes #1674

Signed-off-by: fortitude.zhang <fortitude.zhang@gmail.com>
2024-01-21 17:38:24 +01:00
Kristoffer Dalby
1e22f17f36
node selfupdate and fix subnet router when ACL is enabled (#1673)
Fixes #1604

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-01-18 17:30:25 +01:00
Kristoffer Dalby
65376e2842
ensure renabled auto-approve routes works (#1670) 2024-01-18 16:36:47 +01:00
Alexander Halbarth
7e8bf4bfe5
Add Customization Options to DERP Map entry of integrated DERP server (#1565)
Co-authored-by: Alexander Halbarth <alexander.halbarth@alite.at>
Co-authored-by: Bela Lemle <bela.lemle@alite.at>
Co-authored-by: Kristoffer Dalby <kristoffer@dalby.cc>
2024-01-16 16:04:03 +01:00
52 changed files with 3272 additions and 1415 deletions

View file

@ -50,3 +50,16 @@ instead of filing a bug report.
## To Reproduce ## To Reproduce
<!-- Steps to reproduce the behavior. --> <!-- Steps to reproduce the behavior. -->
## Logs and attachments
<!-- Please attach files with:
- Client netmap dump (see below)
- ACL configuration
- Headscale configuration
Dump the netmap of tailscale clients:
`tailscale debug netmap > DESCRIPTIVE_NAME.json`
Please provide information describing the netmap, which client, which headscale version etc.
-->

View file

@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestEnableDisableAutoApprovedRoute
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestEnableDisableAutoApprovedRoute:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestEnableDisableAutoApprovedRoute
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestEnableDisableAutoApprovedRoute$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestPingAllByIPPublicDERP
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestPingAllByIPPublicDERP:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestPingAllByIPPublicDERP
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestPingAllByIPPublicDERP$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestSubnetRouteACL
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestSubnetRouteACL:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestSubnetRouteACL
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestSubnetRouteACL$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) - Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473)
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
- The latest supported client is 1.36 - The latest supported client is 1.38
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
- If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url.
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611) - Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611)
@ -34,17 +34,21 @@ after improving the test harness as part of adopting [#1460](https://github.com/
### Changes ### Changes
Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) - Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) - Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287) - SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) - State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) - Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) - Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) - Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) - Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) - Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) - Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594) - Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700)
- Old structure is now considered deprecated and will be removed in the future.
- Adds additional configuration for PostgreSQL for setting max open, idle conection and idle connection lifetime.
- Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
- Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

View file

@ -6,25 +6,11 @@ import (
"github.com/efekarakus/termcolor" "github.com/efekarakus/termcolor"
"github.com/juanfont/headscale/cmd/headscale/cli" "github.com/juanfont/headscale/cmd/headscale/cli"
"github.com/pkg/profile"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func main() { func main() {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var colors bool var colors bool
switch l := termcolor.SupportLevel(os.Stderr); l { switch l := termcolor.SupportLevel(os.Stderr); l {
case termcolor.Level16M: case termcolor.Level16M:

View file

@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("database.type"), check.Equals, "sqlite")
c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")

View file

@ -94,6 +94,16 @@ derp:
# #
private_key_path: /var/lib/headscale/derp_server_private.key private_key_path: /var/lib/headscale/derp_server_private.key
# This flag can be used, so the DERP map entry for the embedded DERP server is not written automatically,
# it enables the creation of your very own DERP map entry using a locally available file with the parameter DERP.paths
# If you enable the DERP server and set this to false, it is required to add the DERP server to the DERP map using DERP.paths
automatically_add_embedded_derp_region: true
# For better connection stability (especially when using an Exit-Node and DNS is not working),
# it is possible to optionall add the public IPv4 and IPv6 address to the Derp-Map using:
ipv4: 1.2.3.4
ipv6: 2001:db8::1
# List of externally available DERP maps encoded in JSON # List of externally available DERP maps encoded in JSON
urls: urls:
- https://controlplane.tailscale.com/derpmap/default - https://controlplane.tailscale.com/derpmap/default
@ -128,24 +138,28 @@ ephemeral_node_inactivity_timeout: 30m
# In case of doubts, do not touch the default 10s. # In case of doubts, do not touch the default 10s.
node_update_check_interval: 10s node_update_check_interval: 10s
# SQLite config database:
db_type: sqlite3 type: sqlite
# For production: # SQLite config
db_path: /var/lib/headscale/db.sqlite sqlite:
path: /var/lib/headscale/db.sqlite
# # Postgres config # # Postgres config
# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. # postgres:
# db_type: postgres # # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
# db_host: localhost # host: localhost
# db_port: 5432 # port: 5432
# db_name: headscale # name: headscale
# db_user: foo # user: foo
# db_pass: bar # pass: bar
# max_open_conns: 10
# max_idle_conns: 10
# conn_max_idle_time_secs: 3600
# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need # # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. # # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
# db_ssl: false # ssl: false
### TLS configuration ### TLS configuration
# #

View file

@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor:
Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"): Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"):
``` ```
New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
``` ```

View file

@ -12,7 +12,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@ -33,6 +32,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog" zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/pkg/profile"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
zl "github.com/rs/zerolog" zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -48,6 +48,7 @@ import (
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
@ -61,7 +62,7 @@ var (
"unknown value for Lets Encrypt challenge type", "unknown value for Lets Encrypt challenge type",
) )
errEmptyInitialDERPMap = errors.New( errEmptyInitialDERPMap = errors.New(
"initial DERPMap is empty, Headscale requries at least one entry", "initial DERPMap is empty, Headscale requires at least one entry",
) )
) )
@ -116,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
} }
var dbString string
switch cfg.DBtype {
case db.Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.DBhost,
cfg.DBname,
cfg.DBuser,
)
if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
}
if cfg.DBport != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.DBport)
}
if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
}
case db.Sqlite:
dbString = cfg.DBpath
default:
return nil, errUnsupportedDatabase
}
registrationCache := cache.New( registrationCache := cache.New(
registerCacheExpiration, registerCacheExpiration,
registerCacheCleanup, registerCacheCleanup,
@ -154,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
noisePrivateKey: noisePrivateKey, noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache, registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{}, pollNetMapStreamWG: sync.WaitGroup{},
@ -163,9 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
} }
database, err := db.NewHeadscaleDatabase( database, err := db.NewHeadscaleDatabase(
cfg.DBtype, cfg.Database,
dbString,
app.dbDebug,
app.nodeNotifier, app.nodeNotifier,
cfg.IPPrefixes, cfg.IPPrefixes,
cfg.BaseDomain) cfg.BaseDomain)
@ -234,8 +200,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
var update types.StateUpdate
var changed bool
for range ticker.C { for range ticker.C {
h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout) if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
} }
} }
@ -246,9 +227,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
lastCheck := time.Unix(0, 0) lastCheck := time.Unix(0, 0)
var update types.StateUpdate
var changed bool
for range ticker.C { for range ticker.C {
lastCheck = h.db.ExpireExpiredNodes(lastCheck) if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
} }
} }
@ -268,7 +264,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
case <-ticker.C: case <-ticker.C:
log.Info().Msg("Fetching DERPMap updates") log.Info().Msg("Fetching DERPMap updates")
h.DERPMap = derp.GetDERPMap(h.cfg.DERP) h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion() region, _ := h.DERPServer.GenerateRegion()
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
} }
@ -278,7 +274,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
DERPMap: h.DERPMap, DERPMap: h.DERPMap,
} }
if stateUpdate.Valid() { if stateUpdate.Valid() {
h.nodeNotifier.NotifyAll(stateUpdate) ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
} }
} }
} }
@ -485,6 +482,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches a GIN server with the Headscale API. // Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var err error var err error
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
@ -501,7 +511,9 @@ func (h *Headscale) Serve() error {
return err return err
} }
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
}
go h.DERPServer.ServeSTUN() go h.DERPServer.ServeSTUN()
} }
@ -708,14 +720,16 @@ func (h *Headscale) Serve() error {
var tailsqlContext context.Context var tailsqlContext context.Context
if tailsqlEnabled { if tailsqlEnabled {
if h.cfg.DBtype != db.Sqlite { if h.cfg.Database.Type != types.DatabaseSqlite {
log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite) log.Fatal().
Str("type", h.cfg.Database.Type).
Msgf("tailsql only support %q", types.DatabaseSqlite)
} }
if tailsqlTSKey == "" { if tailsqlTSKey == "" {
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
} }
tailsqlContext = context.Background() tailsqlContext = context.Background()
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath) go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
} }
// Handle common process-killing signals so we can gracefully shut down: // Handle common process-killing signals so we can gracefully shut down:
@ -751,7 +765,8 @@ func (h *Headscale) Serve() error {
Str("path", aclPath). Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
h.nodeNotifier.NotifyAll(types.StateUpdate{ ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate, Type: types.StateFullUpdate,
}) })
} }

View file

@ -1,6 +1,7 @@
package hscontrol package hscontrol
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -8,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -199,6 +201,19 @@ func (h *Headscale) handleRegister(
return return
} }
// When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed
if node.NodeKey.String() != registerRequest.NodeKey.String() &&
registerRequest.OldNodeKey.IsZero() && !node.IsExpired() {
h.handleNodeKeyRefresh(
writer,
registerRequest,
*node,
machineKey,
)
return
}
if registerRequest.Followup != "" { if registerRequest.Followup != "" {
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
@ -230,8 +245,6 @@ func (h *Headscale) handleRegister(
// handleAuthKey contains the logic to manage auth key client registration // handleAuthKey contains the logic to manage auth key client registration
// When using Noise, the machineKey is Zero. // When using Noise, the machineKey is Zero.
//
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
@ -298,6 +311,9 @@ func (h *Headscale) handleAuthKey(
nodeKey := registerRequest.NodeKey nodeKey := registerRequest.NodeKey
var update types.StateUpdate
var mkey key.MachinePublic
// retrieve node information if it exist // retrieve node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new node and we will move // exist, then this is a new node and we will move
@ -311,7 +327,7 @@ func (h *Headscale) handleAuthKey(
node.NodeKey = nodeKey node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID) node.AuthKeyID = uint(pak.ID)
err := h.db.NodeSetExpiry(node, registerRequest.Expiry) err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -322,10 +338,13 @@ func (h *Headscale) handleAuthKey(
return return
} }
mkey = node.MachineKey
update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
aclTags := pak.Proto().GetAclTags() aclTags := pak.Proto().GetAclTags()
if len(aclTags) > 0 { if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.db.SetTags(node, aclTags) err = h.db.SetTags(node.ID, aclTags)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -357,6 +376,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
UserID: pak.User.ID, UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey, MachineKey: machineKey,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry, Expiry: &registerRequest.Expiry,
@ -380,9 +400,18 @@ func (h *Headscale) handleAuthKey(
return return
} }
mkey = node.MachineKey
update = types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from auth.handleAuthKey",
}
} }
err = h.db.UsePreAuthKey(pak) err = h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak)
})
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -424,6 +453,13 @@ func (h *Headscale) handleAuthKey(
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write response")
return
}
// TODO(kradalby): if notifying after register make sense.
if update.Valid() {
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
} }
log.Info(). log.Info().
@ -489,7 +525,7 @@ func (h *Headscale) handleNodeLogOut(
Msg("Client requested logout") Msg("Client requested logout")
now := time.Now() now := time.Now()
err := h.db.NodeSetExpiry(&node, now) err := h.db.NodeSetExpiry(node.ID, now)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -500,17 +536,10 @@ func (h *Headscale) handleNodeLogOut(
return return
} }
stateUpdate := types.StateUpdate{ stateUpdate := types.StateUpdateExpire(node.ID, now)
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &now,
},
},
}
if stateUpdate.Valid() { if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
} }
resp.AuthURL = "" resp.AuthURL = ""
@ -541,7 +570,7 @@ func (h *Headscale) handleNodeLogOut(
} }
if node.IsEphemeral() { if node.IsEphemeral() {
err = h.db.DeleteNode(&node) err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -549,6 +578,15 @@ func (h *Headscale) handleNodeLogOut(
Msg("Cannot delete ephemeral node from the database") Msg("Cannot delete ephemeral node from the database")
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return return
} }
@ -620,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh(
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey) err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
})
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -13,16 +13,23 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
"gorm.io/gorm"
) )
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
return getAvailableIPs(rx, hsdb.ipPrefixes)
})
}
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
var ips types.NodeAddresses var ips types.NodeAddresses
var err error var err error
for _, ipPrefix := range hsdb.ipPrefixes { for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr var ip *netip.Addr
ip, err = hsdb.getAvailableIP(ipPrefix) ip, err = getAvailableIP(rx, ipPrefix)
if err != nil { if err != nil {
return ips, err return ips, err
} }
@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return ips, err return ips, err
} }
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := hsdb.getUsedIPs() usedIps, err := getUsedIPs(rx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro
} }
} }
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model, // FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet. // to begin experimenting with a dual stack tailnet.
var addressesSlices []string var addressesSlices []string
hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, slice := range addressesSlices { for _, slice := range addressesSlices {

View file

@ -7,10 +7,16 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
) )
func (s *Suite) TestGetAvailableIp(c *check.C) { func (s *Suite) TestGetAvailableIp(c *check.C) {
ips, err := db.getAvailableIPs() tx := db.DB.Begin()
defer tx.Rollback()
ips, err := getAvailableIPs(tx, []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
node := types.Node{ node := types.Node{
@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
} }
db.db.Save(&node) db.Write(func(tx *gorm.DB) error {
return tx.Save(&node).Error
usedIps, err := db.getUsedIPs() })
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1") expected := netip.MustParseAddr("10.27.0.1")
@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
} }
func (s *Suite) TestGetMultiIp(c *check.C) { func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := db.CreateUser("test-ip-multi") user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
ipPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
}
for index := 1; index <= 350; index++ { for index := 1; index <= 350; index++ {
db.ipAllocationMutex.Lock() tx := db.DB.Begin()
ips, err := db.getAvailableIPs() ips, err := getAvailableIPs(tx, ipPrefixes)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = getNode(tx, "test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
node := types.Node{ node := types.Node{
@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
} }
db.db.Save(&node) tx.Save(&node)
c.Assert(tx.Commit().Error, check.IsNil)
db.ipAllocationMutex.Unlock()
} }
usedIps, err := db.getUsedIPs() usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1") expected0 := netip.MustParseAddr("10.27.0.1")
@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
node := types.Node{ node := types.Node{
@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
ips2, err := db.getAvailableIPs() ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
func (hsdb *HSDatabase) CreateAPIKey( func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time, expiration *time.Time,
) (string, *types.APIKey, error) { ) (string, *types.APIKey, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
Expiration: expiration, Expiration: expiration,
} }
if err := hsdb.db.Save(&key).Error; err != nil { if err := hsdb.DB.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err) return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
} }
@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
// ListAPIKeys returns the list of ApiKeys for a user. // ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
keys := []types.APIKey{} keys := []types.APIKey{}
if err := hsdb.db.Find(&keys).Error; err != nil { if err := hsdb.DB.Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
// GetAPIKey returns a ApiKey for a given key. // GetAPIKey returns a ApiKey for a given key.
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{} key := types.APIKey{}
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
// GetAPIKeyByID returns a ApiKey for a given id. // GetAPIKeyByID returns a ApiKey for a given id.
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{} key := types.APIKey{}
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
// does not exist. // does not exist.
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
hsdb.mu.Lock() if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
defer hsdb.mu.Unlock()
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
return result.Error return result.Error
} }
@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
// ExpireAPIKey marks a ApiKey as expired. // ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
hsdb.mu.Lock() if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
} }
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
prefix, hash, found := strings.Cut(keyStr, ".") prefix, hash, found := strings.Cut(keyStr, ".")
if !found { if !found {
return false, ErrAPIKeyFailedToParse return false, ErrAPIKeyFailedToParse

View file

@ -6,24 +6,20 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"github.com/go-gormigrate/gormigrate/v2" "github.com/go-gormigrate/gormigrate/v2"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
)
const ( "github.com/juanfont/headscale/hscontrol/notifier"
Postgres = "postgres" "github.com/juanfont/headscale/hscontrol/types"
Sqlite = "sqlite3" "github.com/juanfont/headscale/hscontrol/util"
) )
var errDatabaseNotSupported = errors.New("database type not supported") var errDatabaseNotSupported = errors.New("database type not supported")
@ -36,12 +32,7 @@ type KV struct {
} }
type HSDatabase struct { type HSDatabase struct {
db *gorm.DB DB *gorm.DB
notifier *notifier.Notifier
mu sync.RWMutex
ipAllocationMutex sync.Mutex
ipPrefixes []netip.Prefix ipPrefixes []netip.Prefix
baseDomain string baseDomain string
@ -50,18 +41,20 @@ type HSDatabase struct {
// TODO(kradalby): assemble this struct from toptions or something typed // TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments. // rather than arguments.
func NewHeadscaleDatabase( func NewHeadscaleDatabase(
dbType, connectionAddr string, cfg types.DatabaseConfig,
debug bool,
notifier *notifier.Notifier, notifier *notifier.Notifier,
ipPrefixes []netip.Prefix, ipPrefixes []netip.Prefix,
baseDomain string, baseDomain string,
) (*HSDatabase, error) { ) (*HSDatabase, error) {
dbConn, err := openDB(dbType, connectionAddr, debug) dbConn, err := openDB(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{ migrations := gormigrate.New(
dbConn,
gormigrate.DefaultOptions,
[]*gormigrate.Migration{
// New migrations should be added as transactions at the end of this list. // New migrations should be added as transactions at the end of this list.
// The initial commit here is quite messy, completely out of order and // The initial commit here is quite messy, completely out of order and
// has no versioning and is the tech debt of not having versioned migrations // has no versioning and is the tech debt of not having versioned migrations
@ -70,7 +63,7 @@ func NewHeadscaleDatabase(
{ {
ID: "202312101416", ID: "202312101416",
Migrate: func(tx *gorm.DB) error { Migrate: func(tx *gorm.DB) error {
if dbType == Postgres { if cfg.Type == types.DatabasePostgres {
tx.Exec(`create extension if not exists "uuid-ossp";`) tx.Exec(`create extension if not exists "uuid-ossp";`)
} }
@ -78,23 +71,28 @@ func NewHeadscaleDatabase(
// the big rename from Machine to Node // the big rename from Machine to Node
_ = tx.Migrator().RenameTable("machines", "nodes") _ = tx.Migrator().RenameTable("machines", "nodes")
_ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id") _ = tx.Migrator().
RenameColumn(&types.Route{}, "machine_id", "node_id")
err = tx.AutoMigrate(types.User{}) err = tx.AutoMigrate(types.User{})
if err != nil { if err != nil {
return err return err
} }
_ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id") _ = tx.Migrator().
_ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") RenameColumn(&types.Node{}, "namespace_id", "user_id")
_ = tx.Migrator().
RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
_ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses") _ = tx.Migrator().
RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
_ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
// GivenName is used as the primary source of DNS names, make sure // GivenName is used as the primary source of DNS names, make sure
// the field is populated and normalized if it was not when the // the field is populated and normalized if it was not when the
// node was registered. // node was registered.
_ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name") _ = tx.Migrator().
RenameColumn(&types.Node{}, "nickname", "given_name")
// If the Node table has a column for registered, // If the Node table has a column for registered,
// find all occourences of "false" and drop them. Then // find all occourences of "false" and drop them. Then
@ -147,7 +145,9 @@ func NewHeadscaleDatabase(
DiscoKey string DiscoKey string
} }
var results []result var results []result
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").
Find(&results).
Error
if err != nil { if err != nil {
return err return err
} }
@ -180,7 +180,8 @@ func NewHeadscaleDatabase(
} }
if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...") log.Info().
Msgf("Database has legacy enabled_routes column in node, migrating...")
type NodeAux struct { type NodeAux struct {
ID uint64 ID uint64
@ -188,7 +189,10 @@ func NewHeadscaleDatabase(
} }
nodesAux := []NodeAux{} nodesAux := []NodeAux{}
err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error err := tx.Table("nodes").
Select("id, enabled_routes").
Scan(&nodesAux).
Error
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Error accessing db") log.Fatal().Err(err).Msg("Error accessing db")
} }
@ -234,7 +238,9 @@ func NewHeadscaleDatabase(
err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes")
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column") log.Error().
Err(err).
Msg("Error dropping enabled_routes column")
} }
} }
@ -310,15 +316,15 @@ func NewHeadscaleDatabase(
return nil return nil
}, },
}, },
}) },
)
if err = migrations.Migrate(); err != nil { if err = migrations.Migrate(); err != nil {
log.Fatal().Err(err).Msgf("Migration failed: %v", err) log.Fatal().Err(err).Msgf("Migration failed: %v", err)
} }
db := HSDatabase{ db := HSDatabase{
db: dbConn, DB: dbConn,
notifier: notifier,
ipPrefixes: ipPrefixes, ipPrefixes: ipPrefixes,
baseDomain: baseDomain, baseDomain: baseDomain,
@ -327,20 +333,19 @@ func NewHeadscaleDatabase(
return &db, err return &db, err
} }
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") // TODO(kradalby): Integrate this with zerolog
var dbLogger logger.Interface var dbLogger logger.Interface
if debug { if cfg.Debug {
dbLogger = logger.Default dbLogger = logger.Default
} else { } else {
dbLogger = logger.Default.LogMode(logger.Silent) dbLogger = logger.Default.LogMode(logger.Silent)
} }
switch dbType { switch cfg.Type {
case Sqlite: case types.DatabaseSqlite:
db, err := gorm.Open( db, err := gorm.Open(
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{ &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger, Logger: dbLogger,
@ -359,16 +364,51 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
return db, err return db, err
case Postgres: case types.DatabasePostgres:
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ dbString := fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.Postgres.Host,
cfg.Postgres.Name,
cfg.Postgres.User,
)
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
}
if cfg.Postgres.Port != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
}
if cfg.Postgres.Pass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
}
db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger, Logger: dbLogger,
}) })
if err != nil {
return nil, err
}
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections)
sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections)
sqlDB.SetConnMaxIdleTime(
time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second,
)
return db, nil
} }
return nil, fmt.Errorf( return nil, fmt.Errorf(
"database of type %s is not supported: %w", "database of type %s is not supported: %w",
dbType, cfg.Type,
errDatabaseNotSupported, errDatabaseNotSupported,
) )
} }
@ -376,7 +416,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
func (hsdb *HSDatabase) PingDB(ctx context.Context) error { func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second) ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel() defer cancel()
sqlDB, err := hsdb.db.DB() sqlDB, err := hsdb.DB.DB()
if err != nil { if err != nil {
return err return err
} }
@ -385,10 +425,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
} }
func (hsdb *HSDatabase) Close() error { func (hsdb *HSDatabase) Close() error {
db, err := hsdb.db.DB() db, err := hsdb.DB.DB()
if err != nil { if err != nil {
return err return err
} }
return db.Close() return db.Close()
} }
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
rx := hsdb.DB.Begin()
defer rx.Rollback()
return fn(rx)
}
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
rx := db.Begin()
defer rx.Rollback()
ret, err := fn(rx)
if err != nil {
var no T
return no, err
}
return ret, nil
}
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
tx := hsdb.DB.Begin()
defer tx.Rollback()
if err := fn(tx); err != nil {
return err
}
return tx.Commit().Error
}
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
tx := db.Begin()
defer tx.Rollback()
ret, err := fn(tx)
if err != nil {
var no T
return no, err
}
return ret, tx.Commit().Error
}

View file

@ -34,22 +34,21 @@ var (
) )
) )
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
defer hsdb.mu.RUnlock() return ListPeers(rx, node)
})
return hsdb.listPeers(node)
} }
func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { // ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Finding direct peers") Msg("Finding direct peers")
nodes := types.Nodes{} nodes := types.Nodes{}
if err := hsdb.db. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) { func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
defer hsdb.mu.RUnlock() return ListNodes(rx)
})
return hsdb.listNodes()
} }
func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { func ListNodes(tx *gorm.DB) (types.Nodes, error) {
nodes := []types.Node{} nodes := types.Nodes{}
if err := hsdb.db. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) { func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByGivenName(givenName)
}
func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
nodes := types.Nodes{} nodes := types.Nodes{}
if err := hsdb.db. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err
return nodes, nil return nodes, nil
} }
// GetNode finds a Node by name and user and returns the Node struct. func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
hsdb.mu.RLock() return getNode(rx, user, name)
defer hsdb.mu.RUnlock() })
}
nodes, err := hsdb.ListNodesByUser(user) // getNode finds a Node by name and user and returns the Node struct.
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
nodes, err := ListNodesByUser(tx, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
return nil, ErrNodeNotFound return nil, ErrNodeNotFound
} }
// GetNodeByGivenName finds a Node by given name and user and returns the Node struct. func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
func (hsdb *HSDatabase) GetNodeByGivenName( return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
user string, return GetNodeByID(rx, id)
givenName string, })
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if err := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).First(&node).Error; err != nil {
return nil, err
}
return nil, ErrNodeNotFound
} }
// GetNodeByID finds a Node by ID and returns the Node struct. // GetNodeByID finds a Node by ID and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
mach := types.Node{} mach := types.Node{}
if result := hsdb.db. if result := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
return &mach, nil return &mach, nil
} }
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
func (hsdb *HSDatabase) GetNodeByMachineKey( return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
machineKey key.MachinePublic, return GetNodeByMachineKey(rx, machineKey)
) (*types.Node, error) { })
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeByMachineKey(machineKey)
} }
func (hsdb *HSDatabase) getNodeByMachineKey( // GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
func GetNodeByMachineKey(
tx *gorm.DB,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*types.Node, error) { ) (*types.Node, error) {
mach := types.Node{} mach := types.Node{}
if result := hsdb.db. if result := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey(
return &mach, nil return &mach, nil
} }
// GetNodeByNodeKey finds a Node by its current NodeKey. func (hsdb *HSDatabase) GetNodeByAnyKey(
func (hsdb *HSDatabase) GetNodeByNodeKey( machineKey key.MachinePublic,
nodeKey key.NodePublic, nodeKey key.NodePublic,
oldNodeKey key.NodePublic,
) (*types.Node, error) { ) (*types.Node, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
defer hsdb.mu.RUnlock() return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
})
node := types.Node{}
if result := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&node, "node_key = ?",
nodeKey.String()); result.Error != nil {
return nil, result.Error
}
return &node, nil
} }
// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByAnyKey( // TODO(kradalby): see if we can remove this.
func GetNodeByAnyKey(
tx *gorm.DB,
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*types.Node, error) { ) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{} node := types.Node{}
if result := hsdb.db. if result := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
return &node, nil return &node, nil
} }
func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error { func (hsdb *HSDatabase) SetTags(
hsdb.mu.RLock() nodeID uint64,
defer hsdb.mu.RUnlock() tags []string,
) error {
if result := hsdb.db.Find(node).First(&node); result.Error != nil { return hsdb.Write(func(tx *gorm.DB) error {
return result.Error return SetTags(tx, nodeID, tags)
} })
return nil
} }
// SetTags takes a Node struct pointer and update the forced tags. // SetTags takes a Node struct pointer and update the forced tags.
func (hsdb *HSDatabase) SetTags( func SetTags(
node *types.Node, tx *gorm.DB,
nodeID uint64,
tags []string, tags []string,
) error { ) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if len(tags) == 0 { if len(tags) == 0 {
return nil return nil
} }
newTags := []string{} newTags := types.StringList{}
for _, tag := range tags { for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) { if !util.StringOrPrefixListContains(newTags, tag) {
newTags = append(newTags, tag) newTags = append(newTags, tag)
} }
} }
if err := hsdb.db.Model(node).Updates(types.Node{ if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
ForcedTags: newTags,
}).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err) return fmt.Errorf("failed to update tags for node in the database: %w", err)
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.SetTags",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil return nil
} }
// RenameNode takes a Node struct and a new GivenName for the nodes // RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it. // and renames it.
func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { func RenameNode(tx *gorm.DB,
hsdb.mu.Lock() nodeID uint64, newName string,
defer hsdb.mu.Unlock() ) error {
err := util.CheckForFQDNRules( err := util.CheckForFQDNRules(
newName, newName,
) )
@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "RenameNode"). Str("func", "RenameNode").
Str("node", node.Hostname). Uint64("nodeID", nodeID).
Str("newName", newName). Str("newName", newName).
Err(err). Err(err).
Msg("failed to rename node") Msg("failed to rename node")
return err return err
} }
node.GivenName = newName
if err := hsdb.db.Model(node).Updates(types.Node{ if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
GivenName: newName,
}).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err) return fmt.Errorf("failed to rename node in the database: %w", err)
} }
stateUpdate := types.StateUpdate{ return nil
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.RenameNode",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
} }
return nil func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry)
})
} }
// NodeSetExpiry takes a Node struct and a new expiry time. // NodeSetExpiry takes a Node struct and a new expiry time.
func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error { func NodeSetExpiry(tx *gorm.DB,
hsdb.mu.Lock() nodeID uint64, expiry time.Time,
defer hsdb.mu.Unlock() ) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
return hsdb.nodeSetExpiry(node, expiry)
} }
func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error { func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
if err := hsdb.db.Model(node).Updates(types.Node{ return hsdb.Write(func(tx *gorm.DB) error {
Expiry: &expiry, return DeleteNode(tx, node, isConnected)
}).Error; err != nil { })
return fmt.Errorf(
"failed to refresh node (update expiration) in the database: %w",
err,
)
}
node.Expiry = &expiry
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &expiry,
},
},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
} }
// DeleteNode deletes a Node from the database. // DeleteNode deletes a Node from the database.
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { // Caller is responsible for notifying all of change.
hsdb.mu.Lock() func DeleteNode(tx *gorm.DB,
defer hsdb.mu.Unlock() node *types.Node,
isConnected map[key.MachinePublic]bool,
return hsdb.deleteNode(node) ) error {
} err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
err := hsdb.deleteNodeRoutes(node)
if err != nil { if err != nil {
return err return err
} }
// Unscoped causes the node to be fully removed from the database. // Unscoped causes the node to be fully removed from the database.
if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil { if err := tx.Unscoped().Delete(&node).Error; err != nil {
return err return err
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil return nil
} }
// UpdateLastSeen sets a node's last seen field indicating that we // UpdateLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node. // have recently communicating with this node.
// This is mostly used to indicate if a node is online and is not func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
// extremely important to make sure is fully correct and to avoid return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
// holding up the hot path, does not contain any locks and isnt
// concurrency safe. But that should be ok.
func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
return hsdb.db.Model(node).Updates(types.Node{
LastSeen: node.LastSeen,
}).Error
} }
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( func RegisterNodeFromAuthCallback(
tx *gorm.DB,
cache *cache.Cache, cache *cache.Cache,
mkey key.MachinePublic, mkey key.MachinePublic,
userName string, userName string,
nodeExpiry *time.Time, nodeExpiry *time.Time,
registrationMethod string, registrationMethod string,
ipPrefixes []netip.Prefix,
) (*types.Node, error) { ) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
log.Debug(). log.Debug().
Str("machine_key", mkey.ShortString()). Str("machine_key", mkey.ShortString()).
Str("userName", userName). Str("userName", userName).
@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
if nodeInterface, ok := cache.Get(mkey.String()); ok { if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok { if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName) user, err := GetUser(tx, userName)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w", "failed to find user in register node from auth callback, %w",
@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
} }
registrationNode.UserID = user.ID registrationNode.UserID = user.ID
registrationNode.User = *user
registrationNode.RegisterMethod = registrationMethod registrationNode.RegisterMethod = registrationMethod
if nodeExpiry != nil { if nodeExpiry != nil {
registrationNode.Expiry = nodeExpiry registrationNode.Expiry = nodeExpiry
} }
node, err := hsdb.registerNode( node, err := RegisterNode(
tx,
registrationNode, registrationNode,
ipPrefixes,
) )
if err == nil { if err == nil {
@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache return nil, ErrNodeNotFoundRegistrationCache
} }
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) { func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
hsdb.mu.Lock() return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
defer hsdb.mu.Unlock() return RegisterNode(tx, node, hsdb.ipPrefixes)
})
return hsdb.registerNode(node)
} }
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { // RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
log.Debug(). log.Debug().
Str("node", node.Hostname). Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()). Str("machine_key", node.MachineKey.ShortString()).
@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
// so we store the node.Expire and node.Nodekey that has been set when // so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache // adding it to the registrationCache
if len(node.IPAddresses) > 0 { if len(node.IPAddresses) > 0 {
if err := hsdb.db.Save(&node).Error; err != nil { if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register existing node in the database: %w", err) return nil, fmt.Errorf("failed register existing node in the database: %w", err)
} }
@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
return &node, nil return &node, nil
} }
hsdb.ipAllocationMutex.Lock() ips, err := getAvailableIPs(tx, ipPrefixes)
defer hsdb.ipAllocationMutex.Unlock()
ips, err := hsdb.getAvailableIPs()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
node.IPAddresses = ips node.IPAddresses = ips
if err := hsdb.db.Save(&node).Error; err != nil { if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err) return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
} }
@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
} }
// NodeSetNodeKey sets the node key of a node and saves it to the database. // NodeSetNodeKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error { func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error {
hsdb.mu.Lock() return tx.Model(node).Updates(types.Node{
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
NodeKey: nodeKey, NodeKey: nodeKey,
}).Error; err != nil { }).Error
return err
} }
return nil
}
// NodeSetMachineKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetMachineKey( func (hsdb *HSDatabase) NodeSetMachineKey(
node *types.Node, node *types.Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) error { ) error {
hsdb.mu.Lock() return hsdb.Write(func(tx *gorm.DB) error {
defer hsdb.mu.Unlock() return NodeSetMachineKey(tx, node, machineKey)
})
if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: machineKey,
}).Error; err != nil {
return err
} }
return nil // NodeSetMachineKey sets the node key of a node and saves it to the database.
func NodeSetMachineKey(
tx *gorm.DB,
node *types.Node,
machineKey key.MachinePublic,
) error {
return tx.Model(node).Updates(types.Node{
MachineKey: machineKey,
}).Error
} }
// NodeSave saves a node object to the database, prefer to use a specific save method rather // NodeSave saves a node object to the database, prefer to use a specific save method rather
// than this. It is intended to be used when we are changing or. // than this. It is intended to be used when we are changing or.
func (hsdb *HSDatabase) NodeSave(node *types.Node) error { // TODO(kradalby): Remove this func, just use Save.
hsdb.mu.Lock() func NodeSave(tx *gorm.DB, node *types.Node) error {
defer hsdb.mu.Unlock() return tx.Save(node).Error
if err := hsdb.db.Save(node).Error; err != nil {
return err
} }
return nil func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetAdvertisedRoutes(rx, node)
})
} }
// GetAdvertisedRoutes returns the routes that are be advertised by the given node. // GetAdvertisedRoutes returns the routes that are be advertised by the given node.
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{} routes := types.Routes{}
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e
return prefixes, nil return prefixes, nil
} }
// GetEnabledRoutes returns the routes that are enabled for the node.
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
defer hsdb.mu.RUnlock() return GetEnabledRoutes(rx, node)
})
return hsdb.getEnabledRoutes(node)
} }
func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { // GetEnabledRoutes returns the routes that are enabled for the node.
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{} routes := types.Routes{}
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
Find(&routes).Error Find(&routes).Error
@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro
return prefixes, nil return prefixes, nil
} }
func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool { func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
route, err := netip.ParsePrefix(routeStr) route, err := netip.ParsePrefix(routeStr)
if err != nil { if err != nil {
return false return false
} }
enabledRoutes, err := hsdb.getEnabledRoutes(node) enabledRoutes, err := GetEnabledRoutes(tx, node)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Could not get enabled routes") log.Error().Err(err).Msg("Could not get enabled routes")
@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
return false return false
} }
func (hsdb *HSDatabase) enableRoutes(
node *types.Node,
routeStrs ...string,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return enableRoutes(tx, node, routeStrs...)
})
}
// enableRoutes enables new routes based on a list of new routes. // enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { func enableRoutes(tx *gorm.DB,
node *types.Node, routeStrs ...string,
) (*types.StateUpdate, error) {
newRoutes := make([]netip.Prefix, len(routeStrs)) newRoutes := make([]netip.Prefix, len(routeStrs))
for index, routeStr := range routeStrs { for index, routeStr := range routeStrs {
route, err := netip.ParsePrefix(routeStr) route, err := netip.ParsePrefix(routeStr)
if err != nil { if err != nil {
return err return nil, err
} }
newRoutes[index] = route newRoutes[index] = route
} }
advertisedRoutes, err := hsdb.getAdvertisedRoutes(node) advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
if err != nil { if err != nil {
return err return nil, err
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
return fmt.Errorf( return nil, fmt.Errorf(
"route (%s) is not available on node %s: %w", "route (%s) is not available on node %s: %w",
node.Hostname, node.Hostname,
newRoute, ErrNodeRouteIsNotAvailable, newRoute, ErrNodeRouteIsNotAvailable,
@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Separate loop so we don't leave things in a half-updated state // Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes { for _, prefix := range newRoutes {
route := types.Route{} route := types.Route{}
err := hsdb.db.Preload("Node"). err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
First(&route).Error First(&route).Error
if err == nil { if err == nil {
@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Mark already as primary if there is only this node offering this subnet // Mark already as primary if there is only this node offering this subnet
// (and is not an exit route) // (and is not an exit route)
if !route.IsExitRoute() { if !route.IsExitRoute() {
route.IsPrimary = hsdb.isUniquePrefix(route) route.IsPrimary = isUniquePrefix(tx, route)
} }
err = hsdb.db.Save(&route).Error err = tx.Save(&route).Error
if err != nil { if err != nil {
return fmt.Errorf("failed to enable route: %w", err) return nil, fmt.Errorf("failed to enable route: %w", err)
} }
} else { } else {
return fmt.Errorf("failed to find route: %w", err) return nil, fmt.Errorf("failed to find route: %w", err)
} }
} }
// Ensure the node has the latest routes when notifying the other // Ensure the node has the latest routes when notifying the other
// nodes // nodes
nRoutes, err := hsdb.getNodeRoutes(node) nRoutes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return fmt.Errorf("failed to read back routes: %w", err) return nil, fmt.Errorf("failed to read back routes: %w", err)
} }
node.Routes = nRoutes node.Routes = nRoutes
@ -729,17 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
Strs("routes", routeStrs). Strs("routes", routeStrs).
Msg("enabling routes") Msg("enabling routes")
stateUpdate := types.StateUpdate{ return &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node}, ChangeNodes: types.Nodes{node},
Message: "called from db.enableRoutes", Message: "created in db.enableRoutes",
} }, nil
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(
stateUpdate, node.MachineKey.String())
}
return nil
} }
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
@ -772,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic, mkey key.MachinePublic,
suppliedName string, suppliedName string,
) (string, error) { ) (string, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
defer hsdb.mu.RUnlock() return GenerateGivenName(rx, mkey, suppliedName)
})
}
func GenerateGivenName(
tx *gorm.DB,
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
givenName, err := generateGivenName(suppliedName, false) givenName, err := generateGivenName(suppliedName, false)
if err != nil { if err != nil {
return "", err return "", err
} }
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
nodes, err := hsdb.listNodesByGivenName(givenName) nodes, err := listNodesByGivenName(tx, givenName)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -805,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName(
return givenName, nil return givenName, nil
} }
func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) { func ExpireEphemeralNodes(tx *gorm.DB,
hsdb.mu.Lock() inactivityThreshhold time.Duration,
defer hsdb.mu.Unlock() ) (types.StateUpdate, bool) {
users, err := ListUsers(tx)
users, err := hsdb.listUsers()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error listing users") log.Error().Err(err).Msg("Error listing users")
return return types.StateUpdate{}, false
} }
expired := make([]tailcfg.NodeID, 0)
for _, user := range users { for _, user := range users {
nodes, err := hsdb.listNodesByUser(user.Name) nodes, err := ListNodesByUser(tx, user.Name)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("user", user.Name). Str("user", user.Name).
Msg("Error listing nodes in user") Msg("Error listing nodes in user")
return return types.StateUpdate{}, false
} }
expired := make([]tailcfg.NodeID, 0)
for idx, node := range nodes { for idx, node := range nodes {
if node.IsEphemeral() && node.LastSeen != nil && if node.IsEphemeral() && node.LastSeen != nil &&
time.Now(). time.Now().
@ -838,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
err = hsdb.deleteNode(nodes[idx]) // empty isConnected map as ephemeral nodes are not routes
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -848,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
} }
} }
// TODO(kradalby): needs to be moved out of transaction
}
if len(expired) > 0 { if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{ return types.StateUpdate{
Type: types.StatePeerRemoved, Type: types.StatePeerRemoved,
Removed: expired, Removed: expired,
}) }, true
}
}
} }
func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { return types.StateUpdate{}, false
hsdb.mu.Lock() }
defer hsdb.mu.Unlock()
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {
// use the time of the start of the function to ensure we // use the time of the start of the function to ensure we
// dont miss some nodes by returning it _after_ we have // dont miss some nodes by returning it _after_ we have
// checked everything. // checked everything.
started := time.Now() started := time.Now()
expiredNodes := make([]*types.Node, 0) expired := make([]*tailcfg.PeerChange, 0)
nodes, err := hsdb.listNodes() nodes, err := ListNodes(tx)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Msg("Error listing nodes to find expired nodes") Msg("Error listing nodes to find expired nodes")
return time.Unix(0, 0) return time.Unix(0, 0), types.StateUpdate{}, false
} }
for index, node := range nodes { for index, node := range nodes {
if node.IsExpired() && if node.IsExpired() &&
@ -882,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
// It will notify about all nodes that has been expired. // It will notify about all nodes that has been expired.
// It should only notify about expired nodes since _last check_. // It should only notify about expired nodes since _last check_.
node.Expiry.After(lastCheck) { node.Expiry.After(lastCheck) {
expiredNodes = append(expiredNodes, &nodes[index]) expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
now := time.Now()
// Do not use setNodeExpiry as that has a notifier hook, which // Do not use setNodeExpiry as that has a notifier hook, which
// can cause a deadlock, we are updating all changed nodes later // can cause a deadlock, we are updating all changed nodes later
// and there is no point in notifiying twice. // and there is no point in notifiying twice.
if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ if err := tx.Model(&nodes[index]).Updates(types.Node{
Expiry: &started, Expiry: &now,
}).Error; err != nil { }).Error; err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -904,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
} }
} }
expired := make([]*tailcfg.PeerChange, len(expiredNodes)) if len(expired) > 0 {
for idx, node := range expiredNodes { return started, types.StateUpdate{
expired[idx] = &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &started,
}
}
// Inform the peers of a node with a lightweight update.
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch, Type: types.StatePeerChangedPatch,
ChangePatches: expired, ChangePatches: expired,
} }, true
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
} }
// Inform the node itself that it has expired. return started, types.StateUpdate{}, false
for _, node := range expiredNodes {
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
}
return started
} }

View file

@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(node) db.DB.Save(node)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
_, err = db.GetNodeByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1), AuthKeyID: uint(1),
} }
db.db.Save(&node) db.DB.Save(&node)
err = db.DeleteNode(&node) err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode(user.Name, "testnode3") _, err = db.getNode(user.Name, "testnode3")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
} }
node0ByID, err := db.GetNodeByID(0) node0ByID, err := db.GetNodeByID(0)
@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID), AuthKeyID: uint(stor[index%2].key.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
} }
aclPolicy := &policy.ACLPolicy{ aclPolicy := &policy.ACLPolicy{
@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Expiry: &time.Time{}, Expiry: &time.Time{},
} }
db.db.Save(node) db.DB.Save(node)
nodeFromDB, err := db.GetNode("test", "testnode") nodeFromDB, err := db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB, check.NotNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, false) c.Assert(nodeFromDB.IsExpired(), check.Equals, false)
now := time.Now() now := time.Now()
err = db.NodeSetExpiry(nodeFromDB, now) err = db.NodeSetExpiry(nodeFromDB.ID, now)
c.Assert(err, check.IsNil)
nodeFromDB, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, true) c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("user-1", "testnode") _, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(node) db.DB.Save(node)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(node) db.DB.Save(node)
// assign simple tags // assign simple tags
sTags := []string{"tag:test", "tag:foo"} sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(node, sTags) err = db.SetTags(node.ID, sTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode") node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
// assign duplicat tags, expect no errors but no doubles in DB // assign duplicat tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(node, eTags) err = db.SetTags(node.ID, eTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode") node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert( c.Assert(
node.ForcedTags, node.ForcedTags,
@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
} }
db.db.Save(&node) db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node0ByID, err := db.GetNodeByID(0) node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.EnableAutoApprovedRoutes(pol, node0ByID) // TODO(kradalby): Check state update
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(node0ByID) enabledRoutes, err := db.GetEnabledRoutes(node0ByID)

View file

@ -20,7 +20,6 @@ var (
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
) )
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func (hsdb *HSDatabase) CreatePreAuthKey( func (hsdb *HSDatabase) CreatePreAuthKey(
userName string, userName string,
reusable bool, reusable bool,
@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKey, error) {
// TODO(kradalby): figure out this lock return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
// hsdb.mu.Lock() return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
// defer hsdb.mu.Unlock() })
}
user, err := hsdb.GetUser(userName) // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
} }
now := time.Now().UTC() now := time.Now().UTC()
kstr, err := hsdb.generateKey() kstr, err := generateKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,9 +72,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
Expiration: expiration, Expiration: expiration,
} }
err = hsdb.db.Transaction(func(db *gorm.DB) error { if err := tx.Save(&key).Error; err != nil {
if err := db.Save(&key).Error; err != nil { return nil, fmt.Errorf("failed to create key in the database: %w", err)
return fmt.Errorf("failed to create key in the database: %w", err)
} }
if len(aclTags) > 0 { if len(aclTags) > 0 {
@ -73,8 +81,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
for _, tag := range aclTags { for _, tag := range aclTags {
if !seenTags[tag] { if !seenTags[tag] {
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf( return nil, fmt.Errorf(
"failed to ceate key tag in the database: %w", "failed to ceate key tag in the database: %w",
err, err,
) )
@ -84,9 +92,6 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
} }
} }
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
return &key, nil return &key, nil
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
defer hsdb.mu.RUnlock() return ListPreAuthKeys(rx, userName)
})
return hsdb.listPreAuthKeys(userName)
} }
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { // ListPreAuthKeys returns the list of PreAuthKeys for a user.
user, err := hsdb.getUser(userName) func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
keys := []types.PreAuthKey{} keys := []types.PreAuthKey{}
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
} }
// GetPreAuthKey returns a PreAuthKey for a given key. // GetPreAuthKey returns a PreAuthKey for a given key.
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
hsdb.mu.RLock() pak, err := ValidatePreAuthKey(tx, key)
defer hsdb.mu.RUnlock()
pak, err := hsdb.ValidatePreAuthKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. // does not exist.
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
hsdb.mu.Lock() return tx.Transaction(func(db *gorm.DB) error {
defer hsdb.mu.Unlock()
return hsdb.destroyPreAuthKey(pak)
}
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
return hsdb.db.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
return result.Error return result.Error
} }
@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
}) })
} }
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock() return hsdb.Write(func(tx *gorm.DB) error {
defer hsdb.mu.Unlock() return ExpirePreAuthKey(tx, k)
})
}
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { // MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
} }
// UsePreAuthKey marks a PreAuthKey as used. // UsePreAuthKey marks a PreAuthKey as used.
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
k.Used = true k.Used = true
if err := hsdb.db.Save(k).Error; err != nil { if err := tx.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err) return fmt.Errorf("failed to update key used status in the database: %w", err)
} }
return nil return nil
} }
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
}
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used. // If returns no error and a PreAuthKey, it can be used.
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
pak := types.PreAuthKey{} pak := types.PreAuthKey{}
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
} }
nodes := types.Nodes{} nodes := types.Nodes{}
if err := hsdb.db. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Where(&types.Node{AuthKeyID: uint(pak.ID)}). Where(&types.Node{AuthKeyID: uint(pak.ID)}).
Find(&nodes).Error; err != nil { Find(&nodes).Error; err != nil {
@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
return &pak, nil return &pak, nil
} }
func (hsdb *HSDatabase) generateKey() (string, error) { func generateKey() (string, error) {
size := 24 size := 24
bytes := make([]byte, size) bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {

View file

@ -6,6 +6,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := db.CreateUser("test2") user, err := db.CreateUser("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now() now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) {
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
_, err = db.ValidatePreAuthKey(pak.Key) _, err = db.ValidatePreAuthKey(pak.Key)
// Ephemeral keys are by definition reusable // Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test7", "testest") _, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
db.ExpireEphemeralNodes(time.Second * 20) db.DB.Transaction(func(tx *gorm.DB) error {
ExpireEphemeralNodes(tx, time.Second*20)
return nil
})
// The machine record should have been deleted // The machine record should have been deleted
_, err = db.GetNode("test7", "testest") _, err = db.getNode("test7", "testest")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak.Used = true pak.Used = true
db.db.Save(&pak) db.DB.Save(&pak)
_, err = db.ValidatePreAuthKey(pak.Key) _, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)

View file

@ -7,23 +7,15 @@ import (
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
var ErrRouteIsNotAvailable = errors.New("route is not available") var ErrRouteIsNotAvailable = errors.New("route is not available")
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { func GetRoutes(tx *gorm.DB) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoutes()
}
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
Find(&routes).Error Find(&routes).Error
@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true). Where("advertised = ? AND enabled = ?", true, true).
@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)). Where("prefix = ?", types.IPPrefix(pref)).
@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID). Where("node_id = ? AND advertised = true", node.ID).
@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
} }
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) { func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
defer hsdb.mu.RUnlock() return GetNodeRoutes(rx, node)
})
return hsdb.getNodeRoutes(node)
} }
func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
Where("node_id = ?", node.ID). Where("node_id = ?", node.ID).
@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
return routes, nil return routes, nil
} }
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoute(id)
}
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
var route types.Route var route types.Route
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Preload("Node.User"). Preload("Node.User").
First(&route, id).Error First(&route, id).Error
@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
return &route, nil return &route, nil
} }
func (hsdb *HSDatabase) EnableRoute(id uint64) error { func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
hsdb.mu.Lock() route, err := GetRoute(tx, id)
defer hsdb.mu.Unlock()
return hsdb.enableRoute(id)
}
func (hsdb *HSDatabase) enableRoute(id uint64) error {
route, err := hsdb.getRoute(id)
if err != nil { if err != nil {
return err return nil, err
} }
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.IsExitRoute() { if route.IsExitRoute() {
return hsdb.enableRoutes( return enableRoutes(
tx,
&route.Node, &route.Node,
types.ExitRouteV4.String(), types.ExitRouteV4.String(),
types.ExitRouteV6.String(), types.ExitRouteV6.String(),
) )
} }
return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String()) return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
} }
func (hsdb *HSDatabase) DisableRoute(id uint64) error { func DisableRoute(tx *gorm.DB,
hsdb.mu.Lock() id uint64,
defer hsdb.mu.Unlock() isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := hsdb.getRoute(id) route, err := GetRoute(tx, id)
if err != nil { if err != nil {
return err return nil, err
} }
var routes types.Routes var routes types.Routes
@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() { if !route.IsExitRoute() {
err = hsdb.failoverRouteWithNotify(route) update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil { if err != nil {
return err return nil, err
} }
route.Enabled = false route.Enabled = false
route.IsPrimary = false route.IsPrimary = false
err = hsdb.db.Save(route).Error err = tx.Save(route).Error
if err != nil { if err != nil {
return err return nil, err
} }
} else { } else {
routes, err = hsdb.getNodeRoutes(&node) routes, err = GetNodeRoutes(tx, &node)
if err != nil { if err != nil {
return err return nil, err
} }
for i := range routes { for i := range routes {
if routes[i].IsExitRoute() { if routes[i].IsExitRoute() {
routes[i].Enabled = false routes[i].Enabled = false
routes[i].IsPrimary = false routes[i].IsPrimary = false
err = hsdb.db.Save(&routes[i]).Error err = tx.Save(&routes[i]).Error
if err != nil { if err != nil {
return err return nil, err
} }
} }
} }
} }
if routes == nil { if routes == nil {
routes, err = hsdb.getNodeRoutes(&node) routes, err = GetNodeRoutes(tx, &node)
if err != nil { if err != nil {
return err return nil, err
} }
} }
node.Routes = routes node.Routes = routes
stateUpdate := types.StateUpdate{ // If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node}, ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DisableRoute", Message: "called from db.DisableRoute",
} }
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
} }
return nil return update, nil
} }
func (hsdb *HSDatabase) DeleteRoute(id uint64) error { func (hsdb *HSDatabase) DeleteRoute(
hsdb.mu.Lock() id uint64,
defer hsdb.mu.Unlock() isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return DeleteRoute(tx, id, isConnected)
})
}
route, err := hsdb.getRoute(id) func DeleteRoute(
tx *gorm.DB,
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil { if err != nil {
return err return nil, err
} }
var routes types.Routes var routes types.Routes
@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() { if !route.IsExitRoute() {
err := hsdb.failoverRouteWithNotify(route) update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil { if err != nil {
return nil return nil, nil
} }
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { if err := tx.Unscoped().Delete(&route).Error; err != nil {
return err return nil, err
} }
} else { } else {
routes, err := hsdb.getNodeRoutes(&node) routes, err := GetNodeRoutes(tx, &node)
if err != nil { if err != nil {
return err return nil, err
} }
routesToDelete := types.Routes{} routesToDelete := types.Routes{}
@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
} }
} }
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err return nil, err
} }
} }
// If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create
// one and return to the caller.
if routes == nil { if routes == nil {
routes, err = hsdb.getNodeRoutes(&node) routes, err = GetNodeRoutes(tx, &node)
if err != nil { if err != nil {
return err return nil, err
} }
} }
node.Routes = routes node.Routes = routes
stateUpdate := types.StateUpdate{ if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node}, ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DeleteRoute", Message: "called from db.DeleteRoute",
} }
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
} }
return nil return update, nil
} }
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
routes, err := hsdb.getNodeRoutes(node) routes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return err return err
} }
for i := range routes { for i := range routes {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return err return err
} }
// TODO(kradalby): This is a bit too aggressive, we could probably // TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all. // figure out which routes needs to be failed over rather than all.
hsdb.failoverRouteWithNotify(&routes[i]) failoverRouteReturnUpdate(tx, isConnected, &routes[i])
} }
return nil return nil
} }
// isUniquePrefix returns if there is another node providing the same route already. // isUniquePrefix returns if there is another node providing the same route already.
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64 var count int64
hsdb.db. tx.Model(&types.Route{}).
Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix, route.Prefix,
route.NodeID, route.NodeID,
@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
return count == 0 return count == 0
} }
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route var route types.Route
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
First(&route).Error First(&route).Error
@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
return &route, nil return &route, nil
} }
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodePrimaryRoutes(rx, node)
})
}
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary. // Exit nodes are not considered for this, as they are never marked as Primary.
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
var routes types.Routes var routes types.Routes
err := hsdb.db. err := tx.
Preload("Node"). Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true). Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
Find(&routes).Error Find(&routes).Error
@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
return routes, nil return routes, nil
} }
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool wheter an update should be sent as the
// saved route impacts nodes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
hsdb.mu.Lock() return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
defer hsdb.mu.Unlock() return SaveNodeRoutes(tx, node)
})
return hsdb.saveNodeRoutes(node)
} }
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { // SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool whether an update should be sent as the
// saved route impacts nodes.
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
sendUpdate := false sendUpdate := false
currentRoutes := types.Routes{} currentRoutes := types.Routes{}
err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error err := tx.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil { if err != nil {
return sendUpdate, err return sendUpdate, err
} }
@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised { if !route.Advertised {
currentRoutes[pos].Advertised = true currentRoutes[pos].Advertised = true
err := hsdb.db.Save(&currentRoutes[pos]).Error err := tx.Save(&currentRoutes[pos]).Error
if err != nil { if err != nil {
return sendUpdate, err return sendUpdate, err
} }
@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
} else if route.Advertised { } else if route.Advertised {
currentRoutes[pos].Advertised = false currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false currentRoutes[pos].Enabled = false
err := hsdb.db.Save(&currentRoutes[pos]).Error err := tx.Save(&currentRoutes[pos]).Error
if err != nil { if err != nil {
return sendUpdate, err return sendUpdate, err
} }
@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
Advertised: true, Advertised: true,
Enabled: false, Enabled: false,
} }
err := hsdb.db.Create(&route).Error err := tx.Create(&route).Error
if err != nil { if err != nil {
return sendUpdate, err return sendUpdate, err
} }
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route // EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
// currently have a functioning host that exposes the network. // currently have a functioning host that exposes the network.
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { func EnsureFailoverRouteIsAvailable(
nodeRoutes, err := hsdb.getNodeRoutes(node) tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
node *types.Node,
) (*types.StateUpdate, error) {
nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return nil return nil, nil
} }
var changedNodes types.Nodes
for _, nodeRoute := range nodeRoutes { for _, nodeRoute := range nodeRoutes {
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil { if err != nil {
return err return nil, err
} }
for _, route := range routes { for _, route := range routes {
if route.IsPrimary { if route.IsPrimary {
// if we have a primary route, and the node is connected // if we have a primary route, and the node is connected
// nothing needs to be done. // nothing needs to be done.
if hsdb.notifier.IsConnected(route.Node.MachineKey) { if isConnected[route.Node.MachineKey] {
continue continue
} }
// if not, we need to failover the route // if not, we need to failover the route
err := hsdb.failoverRouteWithNotify(&route) update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
if err != nil { if err != nil {
return err return nil, err
}
if update != nil {
changedNodes = append(changedNodes, update.ChangeNodes...)
} }
} }
} }
} }
return nil if len(changedNodes) != 0 {
} return &types.StateUpdate{
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
if err != nil {
return nil
}
var changedKeys []key.MachinePublic
for _, route := range routes {
changed, err := hsdb.failoverRoute(&route)
if err != nil {
return err
}
changedKeys = append(changedKeys, changed...)
}
changedKeys = lo.Uniq(changedKeys)
var nodes types.Nodes
for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
if err != nil {
return err
}
nodes = append(nodes, node)
}
if nodes != nil {
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: nodes, ChangeNodes: changedNodes,
Message: "called from db.FailoverNodeRoutesWithNotify", Message: "called from db.EnsureFailoverRouteIsAvailable",
} }, nil
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
} }
return nil return nil, nil
} }
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { func failoverRouteReturnUpdate(
changedKeys, err := hsdb.failoverRoute(r) tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) (*types.StateUpdate, error) {
changedKeys, err := failoverRoute(tx, isConnected, r)
if err != nil { if err != nil {
return err return nil, err
} }
log.Trace().
Interface("isConnected", isConnected).
Interface("changedKeys", changedKeys).
Msg("building route failover")
if len(changedKeys) == 0 { if len(changedKeys) == 0 {
return nil return nil, nil
} }
var nodes types.Nodes var nodes types.Nodes
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("loading machines with new primary routes from db")
for _, key := range changedKeys { for _, key := range changedKeys {
node, err := hsdb.getNodeByMachineKey(key) node, err := GetNodeByMachineKey(tx, key)
if err != nil { if err != nil {
return err return nil, err
} }
nodes = append(nodes, node) nodes = append(nodes, node)
} }
log.Trace(). return &types.StateUpdate{
Str("hostname", r.Node.Hostname).
Msg("notifying peers about primary route change")
if nodes != nil {
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: nodes, ChangeNodes: nodes,
Message: "called from db.failoverRouteWithNotify", Message: "called from db.failoverRouteReturnUpdate",
} }, nil
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notified peers about primary route change")
return nil
} }
// failoverRoute takes a route that is no longer available, // failoverRoute takes a route that is no longer available,
@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
// //
// and tries to find a new route to take over its place. // and tries to find a new route to take over its place.
// If the given route was not primary, it returns early. // If the given route was not primary, it returns early.
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { func failoverRoute(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) ([]key.MachinePublic, error) {
if r == nil { if r == nil {
return nil, nil return nil, nil
} }
// This route is not a primary route, and it isnt // This route is not a primary route, and it is not
// being served to nodes. // being served to nodes.
if !r.IsPrimary { if !r.IsPrimary {
return nil, nil return nil, nil
@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return nil, nil return nil, nil
} }
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -585,14 +543,18 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
continue continue
} }
if hsdb.notifier.IsConnected(route.Node.MachineKey) { if !route.Enabled {
continue
}
if isConnected[route.Node.MachineKey] {
newPrimary = &routes[idx] newPrimary = &routes[idx]
break break
} }
} }
// If a new route was not found/available, // If a new route was not found/available,
// return with an error. // return without an error.
// We do not want to update the database as // We do not want to update the database as
// the one currently marked as primary is the // the one currently marked as primary is the
// best we got. // best we got.
@ -606,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Remove primary from the old route // Remove primary from the old route
r.IsPrimary = false r.IsPrimary = false
err = hsdb.db.Save(&r).Error err = tx.Save(&r).Error
if err != nil { if err != nil {
log.Error().Err(err).Msg("error disabling new primary route") log.Error().Err(err).Msg("error disabling new primary route")
@ -619,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Set primary for the new primary // Set primary for the new primary
newPrimary.IsPrimary = true newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error err = tx.Save(&newPrimary).Error
if err != nil { if err != nil {
log.Error().Err(err).Msg("error enabling new primary route") log.Error().Err(err).Msg("error enabling new primary route")
@ -634,19 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
} }
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func (hsdb *HSDatabase) EnableAutoApprovedRoutes( func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy, aclPolicy *policy.ACLPolicy,
node *types.Node, node *types.Node,
) error { ) (*types.StateUpdate, error) {
hsdb.mu.Lock() return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
defer hsdb.mu.Unlock() return EnableAutoApprovedRoutes(tx, aclPolicy, node)
})
if len(node.IPAddresses) == 0 {
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
} }
routes, err := hsdb.getNodeAdvertisedRoutes(node) // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
aclPolicy *policy.ACLPolicy,
node *types.Node,
) (*types.StateUpdate, error) {
if len(node.IPAddresses) == 0 {
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
routes, err := GetNodeAdvertisedRoutes(tx, node)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error(). log.Error().
Caller(). Caller().
@ -654,9 +623,11 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Could not get advertised routes for node") Msg("Could not get advertised routes for node")
return err return nil, err
} }
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
approvedRoutes := types.Routes{} approvedRoutes := types.Routes{}
for _, advertisedRoute := range routes { for _, advertisedRoute := range routes {
@ -673,9 +644,16 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Uint64("nodeId", node.ID). Uint64("nodeId", node.ID).
Msg("Failed to resolve autoApprovers for advertised route") Msg("Failed to resolve autoApprovers for advertised route")
return err return nil, err
} }
log.Trace().
Str("node", node.Hostname).
Str("user", node.User.Name).
Strs("routeApprovers", routeApprovers).
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
Msg("looking up route for autoapproving")
for _, approvedAlias := range routeApprovers { for _, approvedAlias := range routeApprovers {
if approvedAlias == node.User.Name { if approvedAlias == node.User.Name {
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
@ -687,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("alias", approvedAlias). Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy") Msg("Failed to expand alias when processing autoApprovers policy")
return err return nil, err
} }
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first // approvedIPs should contain all of node's IPs if it matches the rule, so check for first
@ -698,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
} }
} }
update := &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{},
Message: "created in db.EnableAutoApprovedRoutes",
}
for _, approvedRoute := range approvedRoutes { for _, approvedRoute := range approvedRoutes {
err := hsdb.enableRoute(uint64(approvedRoute.ID)) perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil { if err != nil {
log.Err(err). log.Err(err).
Str("approvedRoute", approvedRoute.String()). Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID). Uint64("nodeId", node.ID).
Msg("Failed to enable approved route") Msg("Failed to enable approved route")
return err return nil, err
}
} }
return nil update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
}
return update, nil
} }

View file

@ -24,7 +24,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_get_route_node") _, err = db.getNode("test", "test_get_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24") route, err := netip.ParsePrefix("10.0.0.0/24")
@ -42,7 +42,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
db.db.Save(&node) db.DB.Save(&node)
su, err := db.SaveNodeRoutes(&node) su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -52,10 +52,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1) c.Assert(len(advertisedRoutes), check.Equals, 1)
err = db.enableRoutes(&node, "192.168.0.0/24") // TODO(kradalby): check state update
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24") _, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
@ -66,7 +67,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -91,7 +92,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
db.db.Save(&node) db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -106,10 +107,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0) c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = db.enableRoutes(&node, "192.168.0.0/24") _, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24") _, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(&node) enabledRoutes, err := db.GetEnabledRoutes(&node)
@ -117,14 +118,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(len(enabledRoutes), check.Equals, 1) c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through // Adding it twice will just let it pass through
err = db.enableRoutes(&node, "10.0.0.0/24") _, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node) enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = db.enableRoutes(&node, "150.0.10.0/25") _, err = db.enableRoutes(&node, "150.0.10.0/25")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
@ -139,7 +140,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -163,16 +164,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
} }
db.db.Save(&node1) db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false) c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String()) _, err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, route2.String()) _, err = db.enableRoutes(&node1, route2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{ hostInfo2 := tailcfg.Hostinfo{
@ -186,13 +187,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
db.db.Save(&node2) db.DB.Save(&node2)
sendUpdate, err = db.SaveNodeRoutes(&node2) sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false) c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String()) _, err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1) enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -219,7 +220,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(
@ -246,22 +247,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
LastSeen: &now, LastSeen: &now,
} }
db.db.Save(&node1) db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false) c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String()) _, err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String()) _, err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
routes, err := db.GetNodeRoutes(&node1) routes, err := db.GetNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.DeleteRoute(uint64(routes[0].ID)) // TODO(kradalby): check stateupdate
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1) enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -269,17 +271,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes1), check.Equals, 1)
} }
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
func TestFailoverRoute(t *testing.T) { func TestFailoverRoute(t *testing.T) {
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Count/verify updates
var sink chan types.StateUpdate
go func() {
for range sink {
}
}()
machineKeys := []key.MachinePublic{ machineKeys := []key.MachinePublic{
key.NewMachine().Public(), key.NewMachine().Public(),
key.NewMachine().Public(), key.NewMachine().Public(),
@ -291,6 +285,7 @@ func TestFailoverRoute(t *testing.T) {
name string name string
failingRoute types.Route failingRoute types.Route
routes types.Routes routes types.Routes
isConnected map[key.MachinePublic]bool
want []key.MachinePublic want []key.MachinePublic
wantErr bool wantErr bool
}{ }{
@ -371,6 +366,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
routes: types.Routes{ routes: types.Routes{
types.Route{ types.Route{
@ -382,6 +378,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
types.Route{ types.Route{
Model: gorm.Model{ Model: gorm.Model{
@ -392,8 +389,13 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1], MachineKey: machineKeys[1],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
},
want: []key.MachinePublic{ want: []key.MachinePublic{
machineKeys[0], machineKeys[0],
machineKeys[1], machineKeys[1],
@ -411,6 +413,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
routes: types.Routes{ routes: types.Routes{
types.Route{ types.Route{
@ -422,6 +425,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
types.Route{ types.Route{
Model: gorm.Model{ Model: gorm.Model{
@ -432,6 +436,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1], MachineKey: machineKeys[1],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
}, },
want: nil, want: nil,
@ -448,6 +453,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1], MachineKey: machineKeys[1],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
routes: types.Routes{ routes: types.Routes{
types.Route{ types.Route{
@ -459,6 +465,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
types.Route{ types.Route{
Model: gorm.Model{ Model: gorm.Model{
@ -469,6 +476,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1], MachineKey: machineKeys[1],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
types.Route{ types.Route{
Model: gorm.Model{ Model: gorm.Model{
@ -479,8 +487,14 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[2], MachineKey: machineKeys[2],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[1]: true,
machineKeys[2]: true,
},
want: []key.MachinePublic{ want: []key.MachinePublic{
machineKeys[1], machineKeys[1],
machineKeys[0], machineKeys[0],
@ -498,6 +512,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
routes: types.Routes{ routes: types.Routes{
types.Route{ types.Route{
@ -509,6 +524,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
// Offline // Offline
types.Route{ types.Route{
@ -520,8 +536,13 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[3], MachineKey: machineKeys[3],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[3]: false,
},
want: nil, want: nil,
wantErr: false, wantErr: false,
}, },
@ -536,6 +557,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
routes: types.Routes{ routes: types.Routes{
types.Route{ types.Route{
@ -547,6 +569,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0], MachineKey: machineKeys[0],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
// Offline // Offline
types.Route{ types.Route{
@ -558,6 +581,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[3], MachineKey: machineKeys[3],
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true,
}, },
types.Route{ types.Route{
Model: gorm.Model{ Model: gorm.Model{
@ -568,14 +592,61 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1], MachineKey: machineKeys[1],
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
machineKeys[3]: false,
},
want: []key.MachinePublic{ want: []key.MachinePublic{
machineKeys[0], machineKeys[0],
machineKeys[1], machineKeys[1],
}, },
wantErr: false, wantErr: false,
}, },
{
name: "failover-primary-none-enabled",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
// not enabled
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: false,
Enabled: false,
},
},
want: nil,
wantErr: false,
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -583,13 +654,14 @@ func TestFailoverRoute(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test") tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err) assert.NoError(t, err)
notif := notifier.NewNotifier()
db, err = NewHeadscaleDatabase( db, err = NewHeadscaleDatabase(
"sqlite3", types.DatabaseConfig{
tmpDir+"/headscale_test.db", Type: "sqlite3",
false, Sqlite: types.SqliteConfig{
notif, Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{ []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),
}, },
@ -597,23 +669,15 @@ func TestFailoverRoute(t *testing.T) {
) )
assert.NoError(t, err) assert.NoError(t, err)
// Pretend that all the nodes are connected to control
for idx, key := range machineKeys {
// Pretend one node is offline
if idx == 3 {
continue
}
notif.AddNode(key, sink)
}
for _, route := range tt.routes { for _, route := range tt.routes {
if err := db.db.Save(&route).Error; err != nil { if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err) t.Fatalf("failed to create route: %s", err)
} }
} }
got, err := db.failoverRoute(&tt.failingRoute) got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
@ -627,3 +691,231 @@ func TestFailoverRoute(t *testing.T) {
}) })
} }
} }
// func TestDisableRouteFailover(t *testing.T) {
// machineKeys := []key.MachinePublic{
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// }
// tests := []struct {
// name string
// nodes types.Nodes
// routeID uint64
// isConnected map[key.MachinePublic]bool
// wantMachineKey key.MachinePublic
// wantErr string
// }{
// {
// name: "single-route",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// Node: types.Node{
// MachineKey: machineKeys[0],
// },
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[0],
// },
// {
// name: "failover-simple",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "no-failover-offline",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: false,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "failover-to-online",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: true,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
// assert.NoError(t, err)
// // bootstrap db
// datab.DB.Transaction(func(tx *gorm.DB) error {
// for _, node := range tt.nodes {
// err := tx.Save(node).Error
// if err != nil {
// return err
// }
// _, err = SaveNodeRoutes(tx, node)
// if err != nil {
// return err
// }
// }
// return nil
// })
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
// return DisableRoute(tx, tt.routeID, tt.isConnected)
// })
// // if (err.Error() != "") != tt.wantErr {
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
// // return
// // }
// if len(got.ChangeNodes) != 1 {
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
// }
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@ -45,9 +46,12 @@ func (s *Suite) ResetDB(c *check.C) {
log.Printf("database path: %s", tmpDir+"/headscale_test.db") log.Printf("database path: %s", tmpDir+"/headscale_test.db")
db, err = NewHeadscaleDatabase( db, err = NewHeadscaleDatabase(
"sqlite3", types.DatabaseConfig{
tmpDir+"/headscale_test.db", Type: "sqlite3",
false, Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(), notifier.NewNotifier(),
[]netip.Prefix{ []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),

View file

@ -15,22 +15,25 @@ var (
ErrUserStillHasNodes = errors.New("user not empty: node(s) found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
) )
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
return CreateUser(tx, name)
})
}
// CreateUser creates a new User. Returns error if could not be created // CreateUser creates a new User. Returns error if could not be created
// or another user already exists. // or another user already exists.
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
err := util.CheckForFQDNRules(name) err := util.CheckForFQDNRules(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user := types.User{} user := types.User{}
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists return nil, ErrUserExists
} }
user.Name = name user.Name = name
if err := hsdb.db.Create(&user).Error; err != nil { if err := tx.Create(&user).Error; err != nil {
log.Error(). log.Error().
Str("func", "CreateUser"). Str("func", "CreateUser").
Err(err). Err(err).
@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return &user, nil return &user, nil
} }
func (hsdb *HSDatabase) DestroyUser(name string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DestroyUser(tx, name)
})
}
// DestroyUser destroys a User. Returns error if the User does // DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it. // not exist or if there are nodes associated with it.
func (hsdb *HSDatabase) DestroyUser(name string) error { func DestroyUser(tx *gorm.DB, name string) error {
hsdb.mu.Lock() user, err := GetUser(tx, name)
defer hsdb.mu.Unlock()
user, err := hsdb.getUser(name)
if err != nil { if err != nil {
return ErrUserNotFound return ErrUserNotFound
} }
nodes, err := hsdb.listNodesByUser(name) nodes, err := ListNodesByUser(tx, name)
if err != nil { if err != nil {
return err return err
} }
@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
return ErrUserStillHasNodes return ErrUserStillHasNodes
} }
keys, err := hsdb.listPreAuthKeys(name) keys, err := ListPreAuthKeys(tx, name)
if err != nil { if err != nil {
return err return err
} }
for _, key := range keys { for _, key := range keys {
err = hsdb.destroyPreAuthKey(key) err = DestroyPreAuthKey(tx, key)
if err != nil { if err != nil {
return err return err
} }
} }
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { if result := tx.Unscoped().Delete(&user); result.Error != nil {
return result.Error return result.Error
} }
return nil return nil
} }
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return RenameUser(tx, oldName, newName)
})
}
// RenameUser renames a User. Returns error if the User does // RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name. // not exist or if another User exists with the new name.
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { func RenameUser(tx *gorm.DB, oldName, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
var err error var err error
oldUser, err := hsdb.getUser(oldName) oldUser, err := GetUser(tx, oldName)
if err != nil { if err != nil {
return err return err
} }
@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
if err != nil { if err != nil {
return err return err
} }
_, err = hsdb.getUser(newName) _, err = GetUser(tx, newName)
if err == nil { if err == nil {
return ErrUserExists return ErrUserExists
} }
@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
oldUser.Name = newName oldUser.Name = newName
if result := hsdb.db.Save(&oldUser); result.Error != nil { if result := tx.Save(&oldUser); result.Error != nil {
return result.Error return result.Error
} }
return nil return nil
} }
// GetUser fetches a user by name.
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
defer hsdb.mu.RUnlock() return GetUser(rx, name)
})
return hsdb.getUser(name)
} }
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { func GetUser(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{} user := types.User{}
if result := hsdb.db.First(&user, "name = ?", name); errors.Is( if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
return &user, nil return &user, nil
} }
// ListUsers gets all the existing users.
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
hsdb.mu.RLock() return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
defer hsdb.mu.RUnlock() return ListUsers(rx)
})
return hsdb.listUsers()
} }
func (hsdb *HSDatabase) listUsers() ([]types.User, error) { // ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB) ([]types.User, error) {
users := []types.User{} users := []types.User{}
if err := hsdb.db.Find(&users).Error; err != nil { if err := tx.Find(&users).Error; err != nil {
return nil, err return nil, err
} }
@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
} }
// ListNodesByUser gets all the nodes in a given user. // ListNodesByUser gets all the nodes in a given user.
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) { func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByUser(name)
}
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
err := util.CheckForFQDNRules(name) err := util.CheckForFQDNRules(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := hsdb.getUser(name) user, err := GetUser(tx, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
nodes := types.Nodes{} nodes := types.Nodes{}
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }
return nodes, nil return nodes, nil
} }
// AssignNodeToUser assigns a Node to a user.
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
hsdb.mu.Lock() return hsdb.Write(func(tx *gorm.DB) error {
defer hsdb.mu.Unlock() return AssignNodeToUser(tx, node, username)
})
}
// AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
err := util.CheckForFQDNRules(username) err := util.CheckForFQDNRules(username)
if err != nil { if err != nil {
return err return err
} }
user, err := hsdb.getUser(username) user, err := GetUser(tx, username)
if err != nil { if err != nil {
return err return err
} }
node.User = *user node.User = *user
if result := hsdb.db.Save(&node); result.Error != nil { if result := tx.Save(&node); result.Error != nil {
return result.Error return result.Error
} }

View file

@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
err = db.DestroyUser("test") err = db.DestroyUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys // destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
err = db.DestroyUser("test") err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes) c.Assert(err, check.Equals, ErrUserStillHasNodes)
@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
db.db.Save(&node) db.DB.Save(&node)
c.Assert(node.UserID, check.Equals, oldUser.ID) c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, newUser.Name) err = db.AssignNodeToUser(&node, newUser.Name)

View file

@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
RegionID: d.cfg.ServerRegionID, RegionID: d.cfg.ServerRegionID,
HostName: host, HostName: host,
DERPPort: port, DERPPort: port,
IPv4: d.cfg.IPv4,
IPv6: d.cfg.IPv6,
}, },
}, },
} }
@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
localDERPregion.Nodes[0].STUNPort = portSTUN localDERPregion.Nodes[0].STUNPort = portSTUN
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion) log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
return localDERPregion, nil return localDERPregion, nil
} }
@ -208,6 +211,7 @@ func DERPProbeHandler(
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns // An example implementation is found here https://derp.tailscale.com/bootstrap-dns
// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL.
func DERPBootstrapDNSHandler( func DERPBootstrapDNSHandler(
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
) func(http.ResponseWriter, *http.Request) { ) func(http.ResponseWriter, *http.Request) {

View file

@ -8,11 +8,13 @@ import (
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context, ctx context.Context,
request *v1.ExpirePreAuthKeyRequest, request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) { ) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
if err != nil { if err != nil {
return nil, err return err
} }
err = api.h.db.ExpirePreAuthKey(preAuthKey) return db.ExpirePreAuthKey(tx, preAuthKey)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err return nil, err
} }
node, err := api.h.db.RegisterNodeFromAuthCallback( node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
return db.RegisterNodeFromAuthCallback(
tx,
api.h.registrationCache, api.h.registrationCache,
mkey, mkey,
request.GetUser(), request.GetUser(),
nil, nil,
util.RegisterMethodCLI, util.RegisterMethodCLI,
api.h.cfg.IPPrefixes,
) )
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RegisterNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
} }
@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
ctx context.Context, ctx context.Context,
request *v1.SetTagsRequest, request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) { ) (*v1.SetTagsResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
return nil, err
}
}
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, tag := range request.GetTags() { return db.GetNodeByID(tx, request.GetNodeId())
err := validateTag(tag) })
if err != nil { if err != nil {
return &v1.SetTagsResponse{ return &v1.SetTagsResponse{
Node: nil, Node: nil,
}, status.Error(codes.InvalidArgument, err.Error()) }, status.Error(codes.InvalidArgument, err.Error())
} }
}
err = api.h.db.SetTags(node, request.GetTags()) stateUpdate := types.StateUpdate{
if err != nil { Type: types.StatePeerChanged,
return &v1.SetTagsResponse{ ChangeNodes: types.Nodes{node},
Node: nil, Message: "called from api.SetTags",
}, status.Error(codes.Internal, err.Error()) }
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
} }
log.Trace(). log.Trace().
@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
err = api.h.db.DeleteNode( err = api.h.db.DeleteNode(
node, node,
api.h.nodeNotifier.ConnectedMap(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return &v1.DeleteNodeResponse{}, nil return &v1.DeleteNodeResponse{}, nil
} }
@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
ctx context.Context, ctx context.Context,
request *v1.ExpireNodeRequest, request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) { ) (*v1.ExpireNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) now := time.Now()
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry(
tx,
request.GetNodeId(),
now,
)
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
now := time.Now() selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
api.h.db.NodeSetExpiry( stateUpdate := types.StateUpdateExpire(node.ID, now)
node, if stateUpdate.Valid() {
now, ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
) api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -306,19 +365,32 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context, ctx context.Context,
request *v1.RenameNodeRequest, request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) { ) (*v1.RenameNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
if err != nil { err := db.RenameNode(
return nil, err tx,
} request.GetNodeId(),
err = api.h.db.RenameNode(
node,
request.GetNewName(), request.GetNewName(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RenameNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
Str("new_name", request.GetNewName()). Str("new_name", request.GetNewName()).
@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context, ctx context.Context,
request *v1.ListNodesRequest, request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) { ) (*v1.ListNodesResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap()
if request.GetUser() != "" { if request.GetUser() != "" {
nodes, err := api.h.db.ListNodesByUser(request.GetUser()) nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, request.GetUser())
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) resp.Online = isConnected[node.MachineKey]
response[index] = resp response[index] = resp
} }
@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) resp.Online = isConnected[node.MachineKey]
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
&node, node,
) )
resp.InvalidTags = invalidTags resp.InvalidTags = invalidTags
resp.ValidTags = validTags resp.ValidTags = validTags
@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
ctx context.Context, ctx context.Context,
request *v1.GetRoutesRequest, request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) { ) (*v1.GetRoutesResponse, error) {
routes, err := api.h.db.GetRoutes() routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
return db.GetRoutes(rx)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
ctx context.Context, ctx context.Context,
request *v1.EnableRouteRequest, request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) { ) (*v1.EnableRouteResponse, error) {
err := api.h.db.EnableRoute(request.GetRouteId()) update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnableRoute(tx, request.GetRouteId())
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
api.h.nodeNotifier.NotifyAll(
ctx, *update)
}
return &v1.EnableRouteResponse{}, nil return &v1.EnableRouteResponse{}, nil
} }
@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
ctx context.Context, ctx context.Context,
request *v1.DisableRouteRequest, request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) { ) (*v1.DisableRouteResponse, error) {
err := api.h.db.DisableRoute(request.GetRouteId()) isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, *update)
}
return &v1.DisableRouteResponse{}, nil return &v1.DisableRouteResponse{}, nil
} }
@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context, ctx context.Context,
request *v1.DeleteRouteRequest, request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) { ) (*v1.DeleteRouteResponse, error) {
err := api.h.db.DeleteRoute(request.GetRouteId()) isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
}
return &v1.DeleteRouteResponse{}, nil return &v1.DeleteRouteResponse{}, nil
} }

View file

@ -272,13 +272,26 @@ func (m *Mapper) LiteMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) { ) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
node,
nodeMapToList(m.peers),
)
if err != nil {
return nil, err
}
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.SSHPolicy = sshPolicy
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
} }
func (m *Mapper) KeepAliveResponse( func (m *Mapper) KeepAliveResponse(
@ -380,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse(
} }
if patches, ok := m.patches[uint64(change.NodeID)]; ok { if patches, ok := m.patches[uint64(change.NodeID)]; ok {
patches := append(patches, p) m.patches[uint64(change.NodeID)] = append(patches, p)
m.patches[uint64(change.NodeID)] = patches
} else { } else {
m.patches[uint64(change.NodeID)] = []patch{p} m.patches[uint64(change.NodeID)] = []patch{p}
} }
@ -458,6 +469,8 @@ func (m *Mapper) marshalMapResponse(
switch { switch {
case resp.Peers != nil && len(resp.Peers) > 0: case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full" responseType = "full"
case isSelfUpdate(messages...):
responseType = "self"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil: case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
responseType = "lite" responseType = "lite"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
@ -656,3 +669,13 @@ func appendPeerChanges(
return nil return nil
} }
func isSelfUpdate(messages ...string) bool {
for _, message := range messages {
if strings.Contains(message, types.SelfUpdateIdentifier) {
return true
}
}
return false
}

View file

@ -72,7 +72,7 @@ func tailNode(
} }
var derp string var derp string
if node.Hostinfo.NetInfo != nil { if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
} else { } else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown. derp = "127.3.3.40:0" // Zero means disconnected or unknown.

View file

@ -1,6 +1,7 @@
package notifier package notifier
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -14,24 +15,28 @@ import (
type Notifier struct { type Notifier struct {
l sync.RWMutex l sync.RWMutex
nodes map[string]chan<- types.StateUpdate nodes map[string]chan<- types.StateUpdate
connected map[key.MachinePublic]bool
} }
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{} return &Notifier{
nodes: make(map[string]chan<- types.StateUpdate),
connected: make(map[key.MachinePublic]bool),
}
} }
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node") defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to add node")
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
if n.nodes == nil {
n.nodes = make(map[string]chan<- types.StateUpdate)
}
n.nodes[machineKey.String()] = c n.nodes[machineKey.String()] = c
n.connected[machineKey] = true
log.Trace(). log.Trace().
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node") defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to remove node")
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
if n.nodes == nil { if len(n.nodes) == 0 {
return return
} }
delete(n.nodes, machineKey.String()) delete(n.nodes, machineKey.String())
n.connected[machineKey] = false
log.Trace(). log.Trace().
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
if _, ok := n.nodes[machineKey.String()]; ok { return n.connected[machineKey]
return true
} }
return false // TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
return n.connected
} }
func (n *Notifier) NotifyAll(update types.StateUpdate) { func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(update) n.NotifyWithIgnore(ctx, update)
} }
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignore ...string,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Interface("type", update.Type). Interface("type", update.Type).
Msg("releasing lock, finished notifing") Msg("releasing lock, finished notifying")
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
continue continue
} }
log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") select {
c <- update case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
} }
} }
func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) { func (n *Notifier) NotifyByMachineKey(
ctx context.Context,
update types.StateUpdate,
mKey key.MachinePublic,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Interface("type", update.Type). Interface("type", update.Type).
Msg("releasing lock, finished notifing") Msg("releasing lock, finished notifying")
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
if c, ok := n.nodes[mKey.String()]; ok { if c, ok := n.nodes[mKey.String()]; ok {
c <- update select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
} }
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -569,7 +570,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("node already registered, reauthenticating") Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node, expiry) err := h.db.NodeSetExpiry(node.ID, expiry)
if err != nil { if err != nil {
util.LogErr(err, "Failed to refresh node") util.LogErr(err, "Failed to refresh node")
http.Error( http.Error(
@ -623,6 +624,12 @@ func (h *Headscale) validateNodeForOIDCCallback(
util.LogErr(err, "Failed to write response") util.LogErr(err, "Failed to write response")
} }
stateUpdate := types.StateUpdateExpire(node.ID, expiry)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return nil, true, nil return nil, true, nil
} }
@ -712,14 +719,22 @@ func (h *Headscale) registerNodeForOIDCCallback(
machineKey *key.MachinePublic, machineKey *key.MachinePublic,
expiry time.Time, expiry time.Time,
) error { ) error {
if _, err := h.db.RegisterNodeFromAuthCallback( if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules // TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache, h.registrationCache,
*machineKey, *machineKey,
user.Name, user.Name,
&expiry, &expiry,
util.RegisterMethodOIDC, util.RegisterMethodOIDC,
h.cfg.IPPrefixes,
); err != nil { ); err != nil {
return err
}
return nil
}); err != nil {
util.LogErr(err, "could not register node") util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)

View file

@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
if node.IPAddresses.InIPSet(expanded) { if node.IPAddresses.InIPSet(expanded) {
dests = append(dests, dest) dests = append(dests, dest)
} }
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if node.Hostinfo != nil {
// TODO(kradalby): Evaluate if we should only keep
// the routes if the route is enabled. This will
// require database access in this part of the code.
if len(node.Hostinfo.RoutableIPs) > 0 {
for _, routableIP := range node.Hostinfo.RoutableIPs {
if expanded.ContainsPrefix(routableIP) {
dests = append(dests, dest)
}
}
}
}
} }
if len(dests) > 0 { if len(dests) > 0 {
@ -890,8 +905,14 @@ func (pol *ACLPolicy) TagsOfNode(
validTags := make([]string, 0) validTags := make([]string, 0)
invalidTags := make([]string, 0) invalidTags := make([]string, 0)
// TODO(kradalby): Why is this sometimes nil? coming from tailNode?
if node == nil {
return validTags, invalidTags
}
validTagMap := make(map[string]bool) validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool)
if node.Hostinfo != nil {
for _, tag := range node.Hostinfo.RequestTags { for _, tag := range node.Hostinfo.RequestTags {
owners, err := expandOwnersFromTag(pol, tag) owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) { if errors.Is(err, ErrInvalidTag) {
@ -917,6 +938,7 @@ func (pol *ACLPolicy) TagsOfNode(
for tag := range validTagMap { for tag := range validTagMap {
validTags = append(validTags, tag) validTags = append(validTags, tag)
} }
}
return validTags, invalidTags return validTags, invalidTags
} }

View file

@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{}, want: []tailcfg.FilterRule{},
}, },
{
name: "1604-subnet-routers-are-preserved",
pol: ACLPolicy{
Groups: Groups{
"group:admins": {"user1"},
},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"group:admins:*"},
},
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"10.33.0.0/16:*"},
},
},
},
node: &types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0::1"),
},
User: types.User{Name: "user1"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.2"),
netip.MustParseAddr("fd7a:115c:a1e0::2"),
},
User: types.User{Name: "user1"},
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -4,12 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
xslices "golang.org/x/exp/slices" xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -125,6 +128,18 @@ func (h *Headscale) handlePoll(
return return
} }
if h.ACLPolicy != nil {
// update routes with peer information
update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil {
logErr(err, "Error running auto approved routes")
}
if update != nil {
sendUpdate = true
}
}
} }
// Services is mostly useful for discovery and not critical, // Services is mostly useful for discovery and not critical,
@ -138,41 +153,68 @@ func (h *Headscale) handlePoll(
} }
if sendUpdate { if sendUpdate {
if err := h.db.NodeSave(node); err != nil { if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database") logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
} }
// Send an update to all peers to propagate the new routes
// available.
stateUpdate := types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node}, ChangeNodes: types.Nodes{node},
Message: "called from handlePoll -> update -> new hostinfo", Message: "called from handlePoll -> update -> new hostinfo",
} }
if stateUpdate.Valid() { if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyWithIgnore( h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate, stateUpdate,
node.MachineKey.String()) node.MachineKey.String())
} }
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
return return
} }
} }
if err := h.db.NodeSave(node); err != nil { if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database") logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
} }
// TODO(kradalby): Figure out why patch changes does
// not show up in output from `tailscale debug netmap`.
// stateUpdate := types.StateUpdate{
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{&change},
// }
stateUpdate := types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch, Type: types.StatePeerChanged,
ChangePatches: []*tailcfg.PeerChange{&change}, ChangeNodes: types.Nodes{node},
} }
if stateUpdate.Valid() { if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
h.nodeNotifier.NotifyWithIgnore( h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate, stateUpdate,
node.MachineKey.String()) node.MachineKey.String())
} }
@ -228,7 +270,7 @@ func (h *Headscale) handlePoll(
} }
} }
if err := h.db.NodeSave(node); err != nil { if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database") logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
@ -265,7 +307,10 @@ func (h *Headscale) handlePoll(
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
if h.ACLPolicy != nil { if h.ACLPolicy != nil {
// update routes with peer information // update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) // This state update is ignored as it will be sent
// as part of the whole node
// TODO(kradalby): figure out if that is actually correct
_, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil { if err != nil {
logErr(err, "Error running auto approved routes") logErr(err, "Error running auto approved routes")
} }
@ -301,11 +346,17 @@ func (h *Headscale) handlePoll(
Message: "called from handlePoll -> new node added", Message: "called from handlePoll -> new node added",
} }
if stateUpdate.Valid() { if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore( h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate, stateUpdate,
node.MachineKey.String()) node.MachineKey.String())
} }
if len(node.Routes) > 0 {
go h.pollFailoverRoutes(logErr, "new node", node)
}
// Set up the client stream // Set up the client stream
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -323,15 +374,9 @@ func (h *Headscale) handlePoll(
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)
ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname) ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))
ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
if len(node.Routes) > 0 {
go h.db.EnsureFailoverRouteIsAvailable(node)
}
for { for {
logInfo("Waiting for update on stream channel") logInfo("Waiting for update on stream channel")
select { select {
@ -370,6 +415,17 @@ func (h *Headscale) handlePoll(
var data []byte var data []byte
var err error var err error
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
node, err = h.db.GetNodeByMachineKey(node.MachineKey)
if err != nil {
logErr(err, "Could not get machine from db")
return
}
startMapResp := time.Now()
switch update.Type { switch update.Type {
case types.StateFullUpdate: case types.StateFullUpdate:
logInfo("Sending Full MapResponse") logInfo("Sending Full MapResponse")
@ -378,6 +434,7 @@ func (h *Headscale) handlePoll(
case types.StatePeerChanged: case types.StatePeerChanged:
logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
isConnectedMap := h.nodeNotifier.ConnectedMap()
for _, node := range update.ChangeNodes { for _, node := range update.ChangeNodes {
// If a node is not reported to be online, it might be // If a node is not reported to be online, it might be
// because the value is outdated, check with the notifier. // because the value is outdated, check with the notifier.
@ -385,7 +442,7 @@ func (h *Headscale) handlePoll(
// this might be because it has announced itself, but not // this might be because it has announced itself, but not
// reached the stage to actually create the notifier channel. // reached the stage to actually create the notifier channel.
if node.IsOnline != nil && !*node.IsOnline { if node.IsOnline != nil && !*node.IsOnline {
isOnline := h.nodeNotifier.IsConnected(node.MachineKey) isOnline := isConnectedMap[node.MachineKey]
node.IsOnline = &isOnline node.IsOnline = &isOnline
} }
} }
@ -401,7 +458,7 @@ func (h *Headscale) handlePoll(
if len(update.ChangeNodes) == 1 { if len(update.ChangeNodes) == 1 {
logInfo("Sending SelfUpdate MapResponse") logInfo("Sending SelfUpdate MapResponse")
node = update.ChangeNodes[0] node = update.ChangeNodes[0]
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy) data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier)
} else { } else {
logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.") logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.")
} }
@ -416,8 +473,11 @@ func (h *Headscale) handlePoll(
return return
} }
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response")
// Only send update if there is change // Only send update if there is change
if data != nil { if data != nil {
startWrite := time.Now()
_, err = writer.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
logErr(err, "Could not write the map response") logErr(err, "Could not write the map response")
@ -435,6 +495,7 @@ func (h *Headscale) handlePoll(
return return
} }
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node")
log.Info(). log.Info().
Caller(). Caller().
@ -454,7 +515,7 @@ func (h *Headscale) handlePoll(
go h.updateNodeOnlineStatus(false, node) go h.updateNodeOnlineStatus(false, node)
// Failover the node's routes if any. // Failover the node's routes if any.
go h.db.FailoverNodeRoutesWithNotify(node) go h.pollFailoverRoutes(logErr, "node closing connection", node)
// The connection has been closed, so we can stop polling. // The connection has been closed, so we can stop polling.
return return
@ -467,6 +528,22 @@ func (h *Headscale) handlePoll(
} }
} }
func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) {
update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node)
})
if err != nil {
logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
return
}
if update != nil && !update.Empty() && update.Valid() {
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String())
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers // updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status. // about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@ -486,10 +563,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
}, },
} }
if statusUpdate.Valid() { if statusUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String()) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
} }
err := h.db.UpdateLastSeen(node) err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
})
if err != nil { if err != nil {
log.Error().Err(err).Msg("Cannot update node LastSeen") log.Error().Err(err).Msg("Cannot update node LastSeen")

View file

@ -13,7 +13,7 @@ import (
) )
const ( const (
MinimumCapVersion tailcfg.CapabilityVersion = 56 MinimumCapVersion tailcfg.CapabilityVersion = 58
) )
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol

View file

@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
} }
cfg := types.Config{ cfg := types.Config{
NoisePrivateKeyPath: tmpDir + "/noise_private.key", NoisePrivateKeyPath: tmpDir + "/noise_private.key",
DBtype: "sqlite3", Database: types.DatabaseConfig{
DBpath: tmpDir + "/headscale_test.db", Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
IPPrefixes: []netip.Prefix{ IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),
}, },

View file

@ -1,15 +1,23 @@
package types package types
import ( import (
"context"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"time"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
const (
SelfUpdateIdentifier = "self-update"
DatabasePostgres = "postgres"
DatabaseSqlite = "sqlite3"
)
var ErrCannotParsePrefix = errors.New("cannot parse prefix") var ErrCannotParsePrefix = errors.New("cannot parse prefix")
type IPPrefix netip.Prefix type IPPrefix netip.Prefix
@ -150,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
} }
case StateSelfUpdate: case StateSelfUpdate:
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 { if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node") panic(
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
)
} }
case StateDERPUpdated: case StateDERPUpdated:
if su.DERPMap == nil { if su.DERPMap == nil {
@ -160,3 +170,37 @@ func (su *StateUpdate) Valid() bool {
return true return true
} }
// Empty reports if there are any updates in the StateUpdate.
func (su *StateUpdate) Empty() bool {
switch su.Type {
case StatePeerChanged:
return len(su.ChangeNodes) == 0
case StatePeerChangedPatch:
return len(su.ChangePatches) == 0
case StatePeerRemoved:
return len(su.Removed) == 0
}
return false
}
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(nodeID),
KeyExpiry: &expiry,
},
},
}
}
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(
context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
3*time.Second,
)
return ctx2
}

View file

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model" "github.com/prometheus/common/model"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -20,6 +19,8 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
"github.com/juanfont/headscale/hscontrol/util"
) )
const ( const (
@ -46,16 +47,9 @@ type Config struct {
Log LogConfig Log LogConfig
DisableUpdateCheck bool DisableUpdateCheck bool
DERP DERPConfig Database DatabaseConfig
DBtype string DERP DERPConfig
DBpath string
DBhost string
DBport int
DBname string
DBuser string
DBpass string
DBssl string
TLS TLSConfig TLS TLSConfig
@ -77,6 +71,31 @@ type Config struct {
ACL ACLConfig ACL ACLConfig
} }
type SqliteConfig struct {
Path string
}
type PostgresConfig struct {
Host string
Port int
Name string
User string
Pass string
Ssl string
MaxOpenConnections int
MaxIdleConnections int
ConnMaxIdleTimeSecs int
}
type DatabaseConfig struct {
// Type sets the database type, either "sqlite3" or "postgres"
Type string
Debug bool
Sqlite SqliteConfig
Postgres PostgresConfig
}
type TLSConfig struct { type TLSConfig struct {
CertPath string CertPath string
KeyPath string KeyPath string
@ -112,6 +131,7 @@ type OIDCConfig struct {
type DERPConfig struct { type DERPConfig struct {
ServerEnabled bool ServerEnabled bool
AutomaticallyAddEmbeddedDerpRegion bool
ServerRegionID int ServerRegionID int
ServerRegionCode string ServerRegionCode string
ServerRegionName string ServerRegionName string
@ -121,6 +141,8 @@ type DERPConfig struct {
Paths []string Paths []string
AutoUpdate bool AutoUpdate bool
UpdateFrequency time.Duration UpdateFrequency time.Duration
IPv4 string
IPv6 string
} }
type LogTailConfig struct { type LogTailConfig struct {
@ -162,6 +184,19 @@ func LoadConfig(path string, isFile bool) error {
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv() viper.AutomaticEnv()
viper.RegisterAlias("db_type", "database.type")
// SQLite aliases
viper.RegisterAlias("db_path", "database.sqlite.path")
// Postgres aliases
viper.RegisterAlias("db_host", "database.postgres.host")
viper.RegisterAlias("db_port", "database.postgres.port")
viper.RegisterAlias("db_name", "database.postgres.name")
viper.RegisterAlias("db_user", "database.postgres.user")
viper.RegisterAlias("db_pass", "database.postgres.pass")
viper.RegisterAlias("db_ssl", "database.postgres.ssl")
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType) viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
@ -173,6 +208,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("derp.server.enabled", false) viper.SetDefault("derp.server.enabled", false)
viper.SetDefault("derp.server.stun.enabled", true) viper.SetDefault("derp.server.stun.enabled", true)
viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true)
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock") viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
viper.SetDefault("unix_socket_permission", "0o770") viper.SetDefault("unix_socket_permission", "0o770")
@ -184,6 +220,10 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("cli.insecure", false) viper.SetDefault("cli.insecure", false)
viper.SetDefault("db_ssl", false) viper.SetDefault("db_ssl", false)
viper.SetDefault("database.postgres.ssl", false)
viper.SetDefault("database.postgres.max_open_conns", 10)
viper.SetDefault("database.postgres.max_idle_conns", 10)
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.strip_email_domain", true)
@ -294,8 +334,14 @@ func GetDERPConfig() DERPConfig {
serverRegionCode := viper.GetString("derp.server.region_code") serverRegionCode := viper.GetString("derp.server.region_code")
serverRegionName := viper.GetString("derp.server.region_name") serverRegionName := viper.GetString("derp.server.region_name")
stunAddr := viper.GetString("derp.server.stun_listen_addr") stunAddr := viper.GetString("derp.server.stun_listen_addr")
privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path")) privateKeyPath := util.AbsolutePathFromConfigPath(
viper.GetString("derp.server.private_key_path"),
)
ipv4 := viper.GetString("derp.server.ipv4")
ipv6 := viper.GetString("derp.server.ipv6")
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
"derp.server.automatically_add_embedded_derp_region",
)
if serverEnabled && stunAddr == "" { if serverEnabled && stunAddr == "" {
log.Fatal(). log.Fatal().
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
@ -318,6 +364,11 @@ func GetDERPConfig() DERPConfig {
paths := viper.GetStringSlice("derp.paths") paths := viper.GetStringSlice("derp.paths")
if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 {
log.Fatal().
Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths")
}
autoUpdate := viper.GetBool("derp.auto_update_enabled") autoUpdate := viper.GetBool("derp.auto_update_enabled")
updateFrequency := viper.GetDuration("derp.update_frequency") updateFrequency := viper.GetDuration("derp.update_frequency")
@ -332,6 +383,9 @@ func GetDERPConfig() DERPConfig {
Paths: paths, Paths: paths,
AutoUpdate: autoUpdate, AutoUpdate: autoUpdate,
UpdateFrequency: updateFrequency, UpdateFrequency: updateFrequency,
IPv4: ipv4,
IPv6: ipv6,
AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion,
} }
} }
@ -379,6 +433,45 @@ func GetLogConfig() LogConfig {
} }
} }
func GetDatabaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
switch type_ {
case DatabaseSqlite, DatabasePostgres:
break
case "sqlite":
type_ = "sqlite3"
default:
log.Fatal().
Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
}
return DatabaseConfig{
Type: type_,
Debug: debug,
Sqlite: SqliteConfig{
Path: util.AbsolutePathFromConfigPath(
viper.GetString("database.sqlite.path"),
),
},
Postgres: PostgresConfig{
Host: viper.GetString("database.postgres.host"),
Port: viper.GetInt("database.postgres.port"),
Name: viper.GetString("database.postgres.name"),
User: viper.GetString("database.postgres.user"),
Pass: viper.GetString("database.postgres.pass"),
Ssl: viper.GetString("database.postgres.ssl"),
MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"),
MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"),
ConnMaxIdleTimeSecs: viper.GetInt(
"database.postgres.conn_max_idle_time_secs",
),
},
}
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) { func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") { if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{} dnsConfig := &tailcfg.DNSConfig{}
@ -580,7 +673,7 @@ func GetHeadscaleConfig() (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
oidcClientSecret = string(secretBytes) oidcClientSecret = strings.TrimSpace(string(secretBytes))
} }
return &Config{ return &Config{
@ -607,14 +700,7 @@ func GetHeadscaleConfig() (*Config, error) {
"node_update_check_interval", "node_update_check_interval",
), ),
DBtype: viper.GetString("db_type"), Database: GetDatabaseConfig(),
DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
DBssl: viper.GetString("db_ssl"),
TLS: GetTLSConfig(), TLS: GetTLSConfig(),

View file

@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
// inform peers about smaller changes to the node. // inform peers about smaller changes to the node.
// When a field is added to this function, remember to also add it to: // When a field is added to this function, remember to also add it to:
// - node.ApplyPeerChange // - node.ApplyPeerChange
// - logTracePeerChange in poll.go // - logTracePeerChange in poll.go.
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
ret := tailcfg.PeerChange{ ret := tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID), NodeID: tailcfg.NodeID(node.ID),

View file

@ -6,6 +6,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) {
}) })
} }
} }
func TestApplyPeerChange(t *testing.T) {
tests := []struct {
name string
nodeBefore Node
change *tailcfg.PeerChange
want Node
}{
{
name: "hostinfo-and-netinfo-not-exists",
nodeBefore: Node{},
change: &tailcfg.PeerChange{
DERPRegion: 1,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 1,
},
},
},
},
{
name: "hostinfo-netinfo-not-exists",
nodeBefore: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
},
},
change: &tailcfg.PeerChange{
DERPRegion: 3,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 3,
},
},
},
},
{
name: "hostinfo-netinfo-exists-derp-set",
nodeBefore: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 999,
},
},
},
change: &tailcfg.PeerChange{
DERPRegion: 2,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 2,
},
},
},
},
{
name: "endpoints-not-set",
nodeBefore: Node{},
change: &tailcfg.PeerChange{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
want: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
},
{
name: "endpoints-set",
nodeBefore: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("6.6.6.6:66"),
},
},
change: &tailcfg.PeerChange{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
want: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.nodeBefore.ApplyPeerChange(tt.change)
if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" {
t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -2,7 +2,6 @@ package types
import ( import (
"strconv" "strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -25,9 +24,10 @@ func (n *User) TailscaleUser() *tailcfg.User {
ID: tailcfg.UserID(n.ID), ID: tailcfg.UserID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "", ProfilePicURL: "",
Logins: []tailcfg.LoginID{}, Logins: []tailcfg.LoginID{},
Created: time.Time{}, Created: n.CreatedAt,
} }
return &user return &user
@ -38,6 +38,7 @@ func (n *User) TailscaleLogin() *tailcfg.Login {
ID: tailcfg.LoginID(n.ID), ID: tailcfg.LoginID(n.ID),
LoginName: n.Name, LoginName: n.Name,
DisplayName: n.Name, DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "", ProfilePicURL: "",
} }

View file

@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0 return x.Compare(y) == 0
}) })
var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool {
return x == y
})
var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String() return x.String() == y.String()
}) })
@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
}) })
var Comparers []cmp.Option = []cmp.Option{ var Comparers []cmp.Option = []cmp.Option{
IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer, IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer,
} }

View file

@ -265,6 +265,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })
@ -325,6 +327,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })

View file

@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })
@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })

View file

@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAll[4].GetGivenName(), "node-5") assert.Contains(t, listAll[4].GetGivenName(), "node-5")
for idx := 0; idx < 3; idx++ { for idx := 0; idx < 3; idx++ {
_, err := headscale.Execute( res, err := headscale.Execute(
[]string{ []string{
"headscale", "headscale",
"nodes", "nodes",
@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) {
}, },
) )
assert.Nil(t, err) assert.Nil(t, err)
assert.Contains(t, res, "Node renamed")
} }
var listAllAfterRename []v1.Node var listAllAfterRename []v1.Node

View file

@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) {
defer scenario.Shutdown() defer scenario.Shutdown()
spec := map[string]int{ spec := map[string]int{
"user1": len(MustTestVersions), "user1": 10,
// "user1": len(MustTestVersions),
} }
headscaleConfig := map[string]string{} headscaleConfig := map[string]string{
headscaleConfig["HEADSCALE_DERP_URLS"] = "" "HEADSCALE_DERP_URLS": "",
headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true" "HEADSCALE_DERP_SERVER_ENABLED": "true",
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999" "HEADSCALE_DERP_SERVER_REGION_ID": "999",
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale" "HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" "HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" "HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" "HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
// Envknob for enabling DERP debug logs // Envknob for enabling DERP debug logs
headscaleConfig["DERP_DEBUG_LOGS"] = "true" "DERP_DEBUG_LOGS": "true",
headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true" "DERP_PROBER_DEBUG_LOGS": "true",
}
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec, spec,
@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs() allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) assertNoErrListFQDN(t, err)

View file

@ -26,12 +26,34 @@ func TestPingAllByIP(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{ spec := map[string]int{
"user1": len(MustTestVersions), "user1": len(MustTestVersions),
"user2": len(MustTestVersions), "user2": len(MustTestVersions),
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip")) headscaleConfig := map[string]string{
"HEADSCALE_DERP_URLS": "",
"HEADSCALE_DERP_SERVER_ENABLED": "true",
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
// Envknob for enabling DERP debug logs
"DERP_DEBUG_LOGS": "true",
"DERP_PROBER_DEBUG_LOGS": "true",
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("pingallbyip"),
hsic.WithConfigEnv(headscaleConfig),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
)
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -43,6 +65,46 @@ func TestPingAllByIP(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("pingallbyippubderp"),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })
@ -73,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr) clientIPs := make(map[TailscaleClient][]netip.Addr)
for _, client := range allClients { for _, client := range allClients {
ips, err := client.IPs() ips, err := client.IPs()
@ -112,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allClients, err = scenario.ListTailscaleClients() allClients, err = scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
@ -263,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs() allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) assertNoErrListFQDN(t, err)
@ -320,9 +388,13 @@ func TestTaildrop(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
} }
} }
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} curlCommand := []string{
"curl",
"--unix-socket",
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error { err = retry(10, 1*time.Second, func() error {
result, _, err := client.Execute(curlCommand) result, _, err := client.Execute(curlCommand)
if err != nil { if err != nil {
@ -339,13 +411,23 @@ func TestTaildrop(t *testing.T) {
for _, ft := range fts { for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
} }
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
len(fts),
len(allClients)-1,
ftStr,
)
} }
return err return err
}) })
if err != nil { if err != nil {
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
} }
} }
@ -457,6 +539,8 @@ func TestResolveMagicDNS(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
// Poor mans cache // Poor mans cache
_, err = scenario.ListTailscaleClientsFQDNs() _, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) assertNoErrListFQDN(t, err)
@ -525,6 +609,8 @@ func TestExpireNode(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })
@ -546,7 +632,7 @@ func TestExpireNode(t *testing.T) {
// TODO(kradalby): This is Headscale specific and would not play nicely // TODO(kradalby): This is Headscale specific and would not play nicely
// with other implementations of the ControlServer interface // with other implementations of the ControlServer interface
result, err := headscale.Execute([]string{ result, err := headscale.Execute([]string{
"headscale", "nodes", "expire", "--identifier", "0", "--output", "json", "headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
}) })
assertNoErr(t, err) assertNoErr(t, err)
@ -577,16 +663,38 @@ func TestExpireNode(t *testing.T) {
assertNotNil(t, peerStatus.Expired) assertNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry) assert.NotNil(t, peerStatus.KeyExpiry)
t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) t.Logf(
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
if peerStatus.KeyExpiry != nil { if peerStatus.KeyExpiry != nil {
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) assert.Truef(
t,
peerStatus.KeyExpiry.Before(now),
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
} }
assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) assert.Truef(
t,
peerStatus.Expired,
"node %q should be expired, expired is %v",
peerStatus.HostName,
peerStatus.Expired,
)
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
if !strings.Contains(stderr, "node key has expired") { if !strings.Contains(stderr, "node key has expired") {
t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname()) t.Errorf(
"expected to be unable to ping expired host %q from %q",
node.GetName(),
client.Hostname(),
)
} }
} else { } else {
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
@ -598,7 +706,7 @@ func TestExpireNode(t *testing.T) {
// NeedsLogin means that the node has understood that it is no longer // NeedsLogin means that the node has understood that it is no longer
// valid. // valid.
assert.Equal(t, "NeedsLogin", status.BackendState) assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
} }
} }
} }
@ -627,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
}) })
@ -691,7 +801,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
assert.Truef( assert.Truef(
t, t,
lastSeen.After(lastSeenThreshold), lastSeen.After(lastSeenThreshold),
"lastSeen (%v) was not %s after the threshold (%v)", "node (%s) lastSeen (%v) was not %s after the threshold (%v)",
node.GetName(),
lastSeen, lastSeen,
keepAliveInterval, keepAliveInterval,
lastSeenThreshold, lastSeenThreshold,

View file

@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string {
return map[string]string{ return map[string]string{
"HEADSCALE_LOG_LEVEL": "trace", "HEADSCALE_LOG_LEVEL": "trace",
"HEADSCALE_ACL_POLICY_PATH": "", "HEADSCALE_ACL_POLICY_PATH": "",
"HEADSCALE_DB_TYPE": "sqlite3", "HEADSCALE_DATABASE_TYPE": "sqlite",
"HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3", "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s", "HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s",
"HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10", "HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10",

View file

@ -9,10 +9,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"tailscale.com/types/ipproto"
"tailscale.com/wgengine/filter"
) )
// This test is both testing the routes command and the propagation of // This test is both testing the routes command and the propagation of
@ -83,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, routes, 3) assert.Len(t, routes, 3)
for _, route := range routes { for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true) assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, route.GetEnabled(), false) assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, route.GetIsPrimary(), false) assert.Equal(t, false, route.GetIsPrimary())
} }
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
@ -130,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, enablingRoutes, 3) assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes { for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true) assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, route.GetEnabled(), true) assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, route.GetIsPrimary(), true) assert.Equal(t, true, route.GetIsPrimary())
} }
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
@ -186,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) {
}) })
assertNoErr(t, err) assertNoErr(t, err)
time.Sleep(5 * time.Second)
var disablingRoutes []*v1.Route var disablingRoutes []*v1.Route
err = executeAndUnmarshal( err = executeAndUnmarshal(
headscale, headscale,
@ -204,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) {
assert.Equal(t, true, route.GetAdvertised()) assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeToBeDisabled.GetId() { if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, route.GetEnabled(), false) assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, route.GetIsPrimary(), false) assert.Equal(t, false, route.GetIsPrimary())
} else { } else {
assert.Equal(t, route.GetEnabled(), true) assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, route.GetIsPrimary(), true) assert.Equal(t, true, route.GetIsPrimary())
} }
} }
time.Sleep(5 * time.Second)
// Verify that the clients can see the new routes // Verify that the clients can see the new routes
for _, client := range allClients { for _, client := range allClients {
status, err := client.Status() status, err := client.Status()
@ -289,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// advertise HA route on node 1 and 2 // advertise HA route on node 1 and 2
// ID 1 will be primary // ID 1 will be primary
// ID 2 will be secondary // ID 2 will be secondary
for _, client := range allClients { for _, client := range allClients[:2] {
status, err := client.Status() status, err := client.Status()
assertNoErr(t, err) assertNoErr(t, err)
@ -301,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
_, _, err = client.Execute(command) _, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err) assertNoErrf(t, "failed to advertise route: %s", err)
} else {
t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
} }
} }
@ -323,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
t.Logf("initial routes %#v", routes)
for _, route := range routes { for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised()) assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled()) assert.Equal(t, false, route.GetEnabled())
@ -639,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2) assert.Len(t, routesAfterDisabling1, 2)
t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
@ -778,3 +789,413 @@ func TestHASubnetRouterFailover(t *testing.T) {
) )
} }
} }
func TestEnableDisableAutoApprovedRoute(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
expectedRoutes := "172.0.0.0/24"
user := "enable-disable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 1,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
&policy.ACLPolicy{
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
TagOwners: map[string][]string{
"tag:approve": {user},
},
AutoApprovers: policy.AutoApprovers{
Routes: map[string][]string{
expectedRoutes: {"tag:approve"},
},
},
},
))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
subRouter1 := allClients[0]
// Initially advertise route
command := []string{
"tailscale",
"set",
"--advertise-routes=" + expectedRoutes,
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
time.Sleep(10 * time.Second)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 1)
// All routes should be auto approved and enabled
assert.Equal(t, true, routes[0].GetAdvertised())
assert.Equal(t, true, routes[0].GetEnabled())
assert.Equal(t, true, routes[0].GetIsPrimary())
// Stop advertising route
command = []string{
"tailscale",
"set",
"--advertise-routes=",
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to remove advertised route: %s", err)
time.Sleep(10 * time.Second)
var notAdvertisedRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&notAdvertisedRoutes,
)
assertNoErr(t, err)
assert.Len(t, notAdvertisedRoutes, 1)
// Route is no longer advertised
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary())
// Advertise route again
command = []string{
"tailscale",
"set",
"--advertise-routes=" + expectedRoutes,
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
time.Sleep(10 * time.Second)
var reAdvertisedRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&reAdvertisedRoutes,
)
assertNoErr(t, err)
assert.Len(t, reAdvertisedRoutes, 1)
// All routes should be auto approved and enabled
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary())
}
// TestSubnetRouteACL verifies that Subnet routes are distributed
// as expected when ACLs are activated.
// It implements the issue from
// https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "subnet-route-acl"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 2,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
&policy.ACLPolicy{
Groups: policy.Groups{
"group:admins": {user},
},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"group:admins:*"},
},
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"10.33.0.0/16:*"},
},
// {
// Action: "accept",
// Sources: []string{"group:admins"},
// Destinations: []string{"0.0.0.0/0:*"},
// },
},
},
))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.33.0.0/16",
}
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}
statusJ, err := allClients[j].Status()
if err != nil {
return false
}
return statusI.Self.ID < statusJ.Self.ID
})
subRouter1 := allClients[0]
client := allClients[1]
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
command := []string{
"tailscale",
"set",
"--advertise-routes=" + route,
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 1)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes has been sent to the client,
// they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Nil(t, peerStatus.PrimaryRoutes)
}
}
// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
}
time.Sleep(5 * time.Second)
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 1)
// Node 1 has active route
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
// Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status()
clientStatus, err := client.Status()
assertNoErr(t, err)
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice())
assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1)
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
clientNm, err := client.Netmap()
assertNoErr(t, err)
wantClientFilter := []filter.Match{
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("100.64.0.2/32"),
Ports: filter.PortRange{0, 0xffff},
},
{
Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
}
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
}
subnetNm, err := subRouter1.Netmap()
assertNoErr(t, err)
wantSubnetFilter := []filter.Match{
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("100.64.0.1/32"),
Ports: filter.PortRange{0, 0xffff},
},
{
Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("10.33.0.0/16"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
}
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
}

View file

@ -56,8 +56,8 @@ var (
"1.44": true, // CapVer: 63 "1.44": true, // CapVer: 63
"1.42": true, // CapVer: 61 "1.42": true, // CapVer: 61
"1.40": true, // CapVer: 61 "1.40": true, // CapVer: 61
"1.38": true, // CapVer: 58 "1.38": true, // Oldest supported version, CapVer: 58
"1.36": true, // Oldest supported version, CapVer: 56 "1.36": false, // CapVer: 56
"1.34": false, // CapVer: 51 "1.34": false, // CapVer: 51
"1.32": false, // CapVer: 46 "1.32": false, // CapVer: 46
"1.30": false, "1.30": false,

View file

@ -7,6 +7,8 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/netcheck"
"tailscale.com/types/netmap"
) )
// nolint // nolint
@ -26,6 +28,8 @@ type TailscaleClient interface {
IPs() ([]netip.Addr, error) IPs() ([]netip.Addr, error)
FQDN() (string, error) FQDN() (string, error)
Status() (*ipnstate.Status, error) Status() (*ipnstate.Status, error)
Netmap() (*netmap.NetworkMap, error)
Netcheck() (*netcheck.Report, error)
WaitForNeedsLogin() error WaitForNeedsLogin() error
WaitForRunning() error WaitForRunning() error
WaitForPeers(expected int) error WaitForPeers(expected int) error

View file

@ -17,6 +17,8 @@ import (
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/netcheck"
"tailscale.com/types/netmap"
) )
const ( const (
@ -519,6 +521,53 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
return &status, err return &status, err
} }
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
// Only works with Tailscale 1.56.1 and newer.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
command := []string{
"tailscale",
"debug",
"netmap",
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr)
return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
}
var nm netmap.NetworkMap
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
}
return &nm, err
}
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
command := []string{
"tailscale",
"netcheck",
"--format=json",
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr)
return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err)
}
var nm netcheck.Report
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
}
return &nm, err
}
// FQDN returns the FQDN as a string of the Tailscale instance. // FQDN returns the FQDN as a string of the Tailscale instance.
func (t *TailscaleInContainer) FQDN() (string, error) { func (t *TailscaleInContainer) FQDN() (string, error) {
if t.fqdn != "" { if t.fqdn != "" {
@ -623,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
len(peers), len(peers),
) )
} else { } else {
// Verify that the peers of a given node is Online
// has a hostname and a DERP relay.
for _, peerKey := range peers { for _, peerKey := range peers {
peer := status.Peer[peerKey] peer := status.Peer[peerKey]
if !peer.Online { if !peer.Online {
return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName) return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
} }
if peer.HostName == "" {
return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)
}
if peer.Relay == "" {
return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)
}
} }
} }

View file

@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"tailscale.com/util/cmpver"
) )
const ( const (
@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
for _, addr := range addrs { for _, addr := range addrs {
err := client.Ping(addr, opts...) err := client.Ping(addr, opts...)
if err != nil { if err != nil {
t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else { } else {
success++ success++
} }
@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string)
return success return success
} }
// assertClientsState validates the status and netmap of a list of
// clients for the general case of all to all connectivity.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
t.Helper()
for _, client := range clients {
assertValidStatus(t, client)
assertValidNetmap(t, client)
assertValidNetcheck(t, client)
}
}
// assertValidNetmap asserts that the netmap of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
// This test can only be run on clients from 1.56.1. It will
// automatically pass all clients below that and is safe to call
// for all versions.
func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Helper()
if cmpver.Compare("1.56.1", client.Version()) <= 0 ||
!strings.Contains(client.Hostname(), "unstable") ||
!strings.Contains(client.Hostname(), "head") {
return
}
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}
// assertValidStatus asserts that the status of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper()
status, err := client.Status()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
// This seem to not appear until version 1.56
if status.Self.AllowedIPs != nil {
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
}
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
// This isnt really relevant for Self as it wont be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
// This seem to not appear until version 1.56
if peer.AllowedIPs != nil {
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
}
// Addrs does not seem to appear in the status from peers.
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
// there might be some interesting stuff to test here in the future.
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
}
}
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
t.Helper()
report, err := client.Netcheck()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}
func isSelfClient(client TailscaleClient, addr string) bool { func isSelfClient(client TailscaleClient, addr string) bool {
if addr == client.Hostname() { if addr == client.Hostname() {
return true return true
@ -152,7 +296,7 @@ func isCI() bool {
} }
func dockertestMaxWait() time.Duration { func dockertestMaxWait() time.Duration {
wait := 60 * time.Second //nolint wait := 120 * time.Second //nolint
if isCI() { if isCI() {
wait = 300 * time.Second //nolint wait = 300 * time.Second //nolint