Compare commits
14 commits
6b635b3566
...
cc430af3f9
Author | SHA1 | Date | |
---|---|---|---|
cc430af3f9 | |||
|
c3257e2146 | ||
|
9047c09871 | ||
|
91bb85e7d2 | ||
|
94b30abf56 | ||
|
00e7550e76 | ||
|
83769ba715 | ||
|
cbf57e27a7 | ||
|
4ea12f472a | ||
|
b4210e2c90 | ||
|
a369d57a17 | ||
|
1e22f17f36 | ||
|
65376e2842 | ||
|
7e8bf4bfe5 |
52 changed files with 3272 additions and 1415 deletions
13
.github/ISSUE_TEMPLATE/bug_report.md
vendored
13
.github/ISSUE_TEMPLATE/bug_report.md
vendored
|
@ -50,3 +50,16 @@ instead of filing a bug report.
|
|||
## To Reproduce
|
||||
|
||||
<!-- Steps to reproduce the behavior. -->
|
||||
|
||||
## Logs and attachments
|
||||
|
||||
<!-- Please attach files with:
|
||||
- Client netmap dump (see below)
|
||||
- ACL configuration
|
||||
- Headscale configuration
|
||||
|
||||
Dump the netmap of tailscale clients:
|
||||
`tailscale debug netmap > DESCRIPTIVE_NAME.json`
|
||||
|
||||
Please provide information describing the netmap, which client, which headscale version etc.
|
||||
-->
|
||||
|
|
67
.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml
vendored
Normal file
67
.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
|
||||
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
|
||||
|
||||
name: Integration Test v2 - TestEnableDisableAutoApprovedRoute
|
||||
|
||||
on: [pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
TestEnableDisableAutoApprovedRoute:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- uses: satackey/action-docker-layer-caching@main
|
||||
continue-on-error: true
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v34
|
||||
with:
|
||||
files: |
|
||||
*.nix
|
||||
go.*
|
||||
**/*.go
|
||||
integration_test/
|
||||
config-example.yaml
|
||||
|
||||
- name: Run TestEnableDisableAutoApprovedRoute
|
||||
uses: Wandalen/wretry.action@master
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
attempt_limit: 5
|
||||
command: |
|
||||
nix develop --command -- docker run \
|
||||
--tty --rm \
|
||||
--volume ~/.cache/hs-integration-go:/go \
|
||||
--name headscale-test-suite \
|
||||
--volume $PWD:$PWD -w $PWD/integration \
|
||||
--volume /var/run/docker.sock:/var/run/docker.sock \
|
||||
--volume $PWD/control_logs:/tmp/control \
|
||||
golang:1 \
|
||||
go run gotest.tools/gotestsum@latest -- ./... \
|
||||
-failfast \
|
||||
-timeout 120m \
|
||||
-parallel 1 \
|
||||
-run "^TestEnableDisableAutoApprovedRoute$"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: logs
|
||||
path: "control_logs/*.log"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: pprof
|
||||
path: "control_logs/*.pprof.tar"
|
67
.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml
vendored
Normal file
67
.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
|
||||
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
|
||||
|
||||
name: Integration Test v2 - TestPingAllByIPPublicDERP
|
||||
|
||||
on: [pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
TestPingAllByIPPublicDERP:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- uses: satackey/action-docker-layer-caching@main
|
||||
continue-on-error: true
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v34
|
||||
with:
|
||||
files: |
|
||||
*.nix
|
||||
go.*
|
||||
**/*.go
|
||||
integration_test/
|
||||
config-example.yaml
|
||||
|
||||
- name: Run TestPingAllByIPPublicDERP
|
||||
uses: Wandalen/wretry.action@master
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
attempt_limit: 5
|
||||
command: |
|
||||
nix develop --command -- docker run \
|
||||
--tty --rm \
|
||||
--volume ~/.cache/hs-integration-go:/go \
|
||||
--name headscale-test-suite \
|
||||
--volume $PWD:$PWD -w $PWD/integration \
|
||||
--volume /var/run/docker.sock:/var/run/docker.sock \
|
||||
--volume $PWD/control_logs:/tmp/control \
|
||||
golang:1 \
|
||||
go run gotest.tools/gotestsum@latest -- ./... \
|
||||
-failfast \
|
||||
-timeout 120m \
|
||||
-parallel 1 \
|
||||
-run "^TestPingAllByIPPublicDERP$"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: logs
|
||||
path: "control_logs/*.log"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: pprof
|
||||
path: "control_logs/*.pprof.tar"
|
67
.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml
vendored
Normal file
67
.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
|
||||
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
|
||||
|
||||
name: Integration Test v2 - TestSubnetRouteACL
|
||||
|
||||
on: [pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
TestSubnetRouteACL:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- uses: satackey/action-docker-layer-caching@main
|
||||
continue-on-error: true
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v34
|
||||
with:
|
||||
files: |
|
||||
*.nix
|
||||
go.*
|
||||
**/*.go
|
||||
integration_test/
|
||||
config-example.yaml
|
||||
|
||||
- name: Run TestSubnetRouteACL
|
||||
uses: Wandalen/wretry.action@master
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
attempt_limit: 5
|
||||
command: |
|
||||
nix develop --command -- docker run \
|
||||
--tty --rm \
|
||||
--volume ~/.cache/hs-integration-go:/go \
|
||||
--name headscale-test-suite \
|
||||
--volume $PWD:$PWD -w $PWD/integration \
|
||||
--volume /var/run/docker.sock:/var/run/docker.sock \
|
||||
--volume $PWD/control_logs:/tmp/control \
|
||||
golang:1 \
|
||||
go run gotest.tools/gotestsum@latest -- ./... \
|
||||
-failfast \
|
||||
-timeout 120m \
|
||||
-parallel 1 \
|
||||
-run "^TestSubnetRouteACL$"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: logs
|
||||
path: "control_logs/*.log"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
if: always() && steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
name: pprof
|
||||
path: "control_logs/*.pprof.tar"
|
28
CHANGELOG.md
28
CHANGELOG.md
|
@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
|
|||
- Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473)
|
||||
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
|
||||
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
|
||||
- The latest supported client is 1.36
|
||||
- The latest supported client is 1.38
|
||||
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
|
||||
- If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url.
|
||||
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611)
|
||||
|
@ -34,17 +34,21 @@ after improving the test harness as part of adopting [#1460](https://github.com/
|
|||
|
||||
### Changes
|
||||
|
||||
Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
|
||||
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
|
||||
Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
|
||||
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
|
||||
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
|
||||
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
|
||||
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
||||
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
||||
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
||||
Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
|
||||
Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
|
||||
- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
|
||||
- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
|
||||
- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
|
||||
- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
|
||||
- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
|
||||
- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
|
||||
- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
|
||||
- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
|
||||
- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
|
||||
- Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
|
||||
- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700)
|
||||
- Old structure is now considered deprecated and will be removed in the future.
|
||||
- Adds additional configuration for PostgreSQL for setting max open, idle conection and idle connection lifetime.
|
||||
- Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
|
||||
- Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
|
||||
|
||||
## 0.22.3 (2023-05-12)
|
||||
|
||||
|
|
|
@ -6,25 +6,11 @@ import (
|
|||
|
||||
"github.com/efekarakus/termcolor"
|
||||
"github.com/juanfont/headscale/cmd/headscale/cli"
|
||||
"github.com/pkg/profile"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
|
||||
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
|
||||
err := os.MkdirAll(profilePath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to create profiling directory")
|
||||
}
|
||||
|
||||
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
|
||||
} else {
|
||||
defer profile.Start().Stop()
|
||||
}
|
||||
}
|
||||
|
||||
var colors bool
|
||||
switch l := termcolor.SupportLevel(os.Stderr); l {
|
||||
case termcolor.Level16M:
|
||||
|
|
|
@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
|
|||
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
|
||||
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
|
||||
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
|
||||
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
|
||||
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
|
||||
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
|
||||
c.Assert(viper.GetString("database.type"), check.Equals, "sqlite")
|
||||
c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||
|
@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
|||
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
|
||||
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
|
||||
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
|
||||
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
|
||||
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
|
||||
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
||||
|
|
|
@ -94,6 +94,16 @@ derp:
|
|||
#
|
||||
private_key_path: /var/lib/headscale/derp_server_private.key
|
||||
|
||||
# This flag can be used, so the DERP map entry for the embedded DERP server is not written automatically,
|
||||
# it enables the creation of your very own DERP map entry using a locally available file with the parameter DERP.paths
|
||||
# If you enable the DERP server and set this to false, it is required to add the DERP server to the DERP map using DERP.paths
|
||||
automatically_add_embedded_derp_region: true
|
||||
|
||||
# For better connection stability (especially when using an Exit-Node and DNS is not working),
|
||||
# it is possible to optionall add the public IPv4 and IPv6 address to the Derp-Map using:
|
||||
ipv4: 1.2.3.4
|
||||
ipv6: 2001:db8::1
|
||||
|
||||
# List of externally available DERP maps encoded in JSON
|
||||
urls:
|
||||
- https://controlplane.tailscale.com/derpmap/default
|
||||
|
@ -128,24 +138,28 @@ ephemeral_node_inactivity_timeout: 30m
|
|||
# In case of doubts, do not touch the default 10s.
|
||||
node_update_check_interval: 10s
|
||||
|
||||
# SQLite config
|
||||
db_type: sqlite3
|
||||
database:
|
||||
type: sqlite
|
||||
|
||||
# For production:
|
||||
db_path: /var/lib/headscale/db.sqlite
|
||||
# SQLite config
|
||||
sqlite:
|
||||
path: /var/lib/headscale/db.sqlite
|
||||
|
||||
# # Postgres config
|
||||
# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
|
||||
# db_type: postgres
|
||||
# db_host: localhost
|
||||
# db_port: 5432
|
||||
# db_name: headscale
|
||||
# db_user: foo
|
||||
# db_pass: bar
|
||||
# postgres:
|
||||
# # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
|
||||
# host: localhost
|
||||
# port: 5432
|
||||
# name: headscale
|
||||
# user: foo
|
||||
# pass: bar
|
||||
# max_open_conns: 10
|
||||
# max_idle_conns: 10
|
||||
# conn_max_idle_time_secs: 3600
|
||||
|
||||
# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
|
||||
# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
|
||||
# db_ssl: false
|
||||
# # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
|
||||
# # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
|
||||
# ssl: false
|
||||
|
||||
### TLS configuration
|
||||
#
|
||||
|
|
|
@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor:
|
|||
Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"):
|
||||
|
||||
```
|
||||
New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
|
||||
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
|
||||
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
|
||||
```
|
||||
|
|
107
hscontrol/app.go
107
hscontrol/app.go
|
@ -12,7 +12,6 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
@ -33,6 +32,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/pkg/profile"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -48,6 +48,7 @@ import (
|
|||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/reflection"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
|
@ -61,7 +62,7 @@ var (
|
|||
"unknown value for Lets Encrypt challenge type",
|
||||
)
|
||||
errEmptyInitialDERPMap = errors.New(
|
||||
"initial DERPMap is empty, Headscale requries at least one entry",
|
||||
"initial DERPMap is empty, Headscale requires at least one entry",
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -116,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
||||
}
|
||||
|
||||
var dbString string
|
||||
switch cfg.DBtype {
|
||||
case db.Postgres:
|
||||
dbString = fmt.Sprintf(
|
||||
"host=%s dbname=%s user=%s",
|
||||
cfg.DBhost,
|
||||
cfg.DBname,
|
||||
cfg.DBuser,
|
||||
)
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
}
|
||||
} else {
|
||||
dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
|
||||
}
|
||||
|
||||
if cfg.DBport != 0 {
|
||||
dbString += fmt.Sprintf(" port=%d", cfg.DBport)
|
||||
}
|
||||
|
||||
if cfg.DBpass != "" {
|
||||
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
|
||||
}
|
||||
case db.Sqlite:
|
||||
dbString = cfg.DBpath
|
||||
default:
|
||||
return nil, errUnsupportedDatabase
|
||||
}
|
||||
|
||||
registrationCache := cache.New(
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
|
@ -154,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
|
||||
app := Headscale{
|
||||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
dbString: dbString,
|
||||
noisePrivateKey: noisePrivateKey,
|
||||
registrationCache: registrationCache,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
|
@ -163,9 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
}
|
||||
|
||||
database, err := db.NewHeadscaleDatabase(
|
||||
cfg.DBtype,
|
||||
dbString,
|
||||
app.dbDebug,
|
||||
cfg.Database,
|
||||
app.nodeNotifier,
|
||||
cfg.IPPrefixes,
|
||||
cfg.BaseDomain)
|
||||
|
@ -234,8 +200,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
|||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
|
||||
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
|
||||
var update types.StateUpdate
|
||||
var changed bool
|
||||
for range ticker.C {
|
||||
h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout)
|
||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
|
||||
continue
|
||||
}
|
||||
|
||||
if changed && update.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -246,9 +227,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
|||
ticker := time.NewTicker(interval)
|
||||
|
||||
lastCheck := time.Unix(0, 0)
|
||||
var update types.StateUpdate
|
||||
var changed bool
|
||||
|
||||
for range ticker.C {
|
||||
lastCheck = h.db.ExpireExpiredNodes(lastCheck)
|
||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring nodes")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
|
||||
if changed && update.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,7 +264,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
|||
case <-ticker.C:
|
||||
log.Info().Msg("Fetching DERPMap updates")
|
||||
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
||||
if h.cfg.DERP.ServerEnabled {
|
||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
region, _ := h.DERPServer.GenerateRegion()
|
||||
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
@ -278,7 +274,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
|||
DERPMap: h.DERPMap,
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
h.nodeNotifier.NotifyAll(stateUpdate)
|
||||
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -485,6 +482,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
|
||||
// Serve launches a GIN server with the Headscale API.
|
||||
func (h *Headscale) Serve() error {
|
||||
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
|
||||
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
|
||||
err := os.MkdirAll(profilePath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to create profiling directory")
|
||||
}
|
||||
|
||||
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
|
||||
} else {
|
||||
defer profile.Start().Stop()
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
|
@ -501,7 +511,9 @@ func (h *Headscale) Serve() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
go h.DERPServer.ServeSTUN()
|
||||
}
|
||||
|
@ -708,14 +720,16 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
var tailsqlContext context.Context
|
||||
if tailsqlEnabled {
|
||||
if h.cfg.DBtype != db.Sqlite {
|
||||
log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite)
|
||||
if h.cfg.Database.Type != types.DatabaseSqlite {
|
||||
log.Fatal().
|
||||
Str("type", h.cfg.Database.Type).
|
||||
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
||||
}
|
||||
if tailsqlTSKey == "" {
|
||||
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
||||
}
|
||||
tailsqlContext = context.Background()
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath)
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
|
||||
}
|
||||
|
||||
// Handle common process-killing signals so we can gracefully shut down:
|
||||
|
@ -751,7 +765,8 @@ func (h *Headscale) Serve() error {
|
|||
Str("path", aclPath).
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
h.nodeNotifier.NotifyAll(types.StateUpdate{
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StateFullUpdate,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -8,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -199,6 +201,19 @@ func (h *Headscale) handleRegister(
|
|||
return
|
||||
}
|
||||
|
||||
// When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed
|
||||
if node.NodeKey.String() != registerRequest.NodeKey.String() &&
|
||||
registerRequest.OldNodeKey.IsZero() && !node.IsExpired() {
|
||||
h.handleNodeKeyRefresh(
|
||||
writer,
|
||||
registerRequest,
|
||||
*node,
|
||||
machineKey,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if registerRequest.Followup != "" {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
|
@ -230,8 +245,6 @@ func (h *Headscale) handleRegister(
|
|||
|
||||
// handleAuthKey contains the logic to manage auth key client registration
|
||||
// When using Noise, the machineKey is Zero.
|
||||
//
|
||||
// TODO: check if any locks are needed around IP allocation.
|
||||
func (h *Headscale) handleAuthKey(
|
||||
writer http.ResponseWriter,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
|
@ -298,6 +311,9 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
nodeKey := registerRequest.NodeKey
|
||||
|
||||
var update types.StateUpdate
|
||||
var mkey key.MachinePublic
|
||||
|
||||
// retrieve node information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
|
@ -311,7 +327,7 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
node.NodeKey = nodeKey
|
||||
node.AuthKeyID = uint(pak.ID)
|
||||
err := h.db.NodeSetExpiry(node, registerRequest.Expiry)
|
||||
err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -322,10 +338,13 @@ func (h *Headscale) handleAuthKey(
|
|||
return
|
||||
}
|
||||
|
||||
mkey = node.MachineKey
|
||||
update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
|
||||
|
||||
aclTags := pak.Proto().GetAclTags()
|
||||
if len(aclTags) > 0 {
|
||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||
err = h.db.SetTags(node, aclTags)
|
||||
err = h.db.SetTags(node.ID, aclTags)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -357,6 +376,7 @@ func (h *Headscale) handleAuthKey(
|
|||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
UserID: pak.User.ID,
|
||||
User: pak.User,
|
||||
MachineKey: machineKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Expiry: ®isterRequest.Expiry,
|
||||
|
@ -380,9 +400,18 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
mkey = node.MachineKey
|
||||
update = types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from auth.handleAuthKey",
|
||||
}
|
||||
}
|
||||
|
||||
err = h.db.UsePreAuthKey(pak)
|
||||
err = h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
return db.UsePreAuthKey(tx, pak)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -424,6 +453,13 @@ func (h *Headscale) handleAuthKey(
|
|||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(kradalby): if notifying after register make sense.
|
||||
if update.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
|
||||
}
|
||||
|
||||
log.Info().
|
||||
|
@ -489,7 +525,7 @@ func (h *Headscale) handleNodeLogOut(
|
|||
Msg("Client requested logout")
|
||||
|
||||
now := time.Now()
|
||||
err := h.db.NodeSetExpiry(&node, now)
|
||||
err := h.db.NodeSetExpiry(node.ID, now)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -500,17 +536,10 @@ func (h *Headscale) handleNodeLogOut(
|
|||
return
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: &now,
|
||||
},
|
||||
},
|
||||
}
|
||||
stateUpdate := types.StateUpdateExpire(node.ID, now)
|
||||
if stateUpdate.Valid() {
|
||||
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
resp.AuthURL = ""
|
||||
|
@ -541,7 +570,7 @@ func (h *Headscale) handleNodeLogOut(
|
|||
}
|
||||
|
||||
if node.IsEphemeral() {
|
||||
err = h.db.DeleteNode(&node)
|
||||
err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -549,6 +578,15 @@ func (h *Headscale) handleNodeLogOut(
|
|||
Msg("Cannot delete ephemeral node from the database")
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -620,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh(
|
|||
Str("node", node.Hostname).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
|
||||
err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey)
|
||||
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
|
|
@ -13,16 +13,23 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
|
||||
return getAvailableIPs(rx, hsdb.ipPrefixes)
|
||||
})
|
||||
}
|
||||
|
||||
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
|
||||
var ips types.NodeAddresses
|
||||
var err error
|
||||
for _, ipPrefix := range hsdb.ipPrefixes {
|
||||
for _, ipPrefix := range ipPrefixes {
|
||||
var ip *netip.Addr
|
||||
ip, err = hsdb.getAvailableIP(ipPrefix)
|
||||
ip, err = getAvailableIP(rx, ipPrefix)
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
|
@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
|
|||
return ips, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||
usedIps, err := hsdb.getUsedIPs()
|
||||
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||
usedIps, err := getUsedIPs(rx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro
|
|||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
|
||||
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
// FIXME: This really deserves a better data model,
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
|
||||
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
|
|
|
@ -7,10 +7,16 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
tx := db.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ips, err := getAvailableIPs(tx, []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
})
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
|
@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
|
||||
usedIps, err := db.getUsedIPs()
|
||||
db.Write(func(tx *gorm.DB) error {
|
||||
return tx.Save(&node).Error
|
||||
})
|
||||
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
user, err := db.CreateUser("test-ip-multi")
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
ipPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
}
|
||||
|
||||
for index := 1; index <= 350; index++ {
|
||||
db.ipAllocationMutex.Lock()
|
||||
tx := db.DB.Begin()
|
||||
|
||||
ips, err := db.getAvailableIPs()
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = getNode(tx, "test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
|
@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
|
||||
db.ipAllocationMutex.Unlock()
|
||||
tx.Save(&node)
|
||||
c.Assert(tx.Commit().Error, check.IsNil)
|
||||
}
|
||||
|
||||
usedIps, err := db.getUsedIPs()
|
||||
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
|
||||
return getUsedIPs(rx)
|
||||
})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||
|
@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
node := types.Node{
|
||||
|
@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
ips2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
|
|
@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
|||
func (hsdb *HSDatabase) CreateAPIKey(
|
||||
expiration *time.Time,
|
||||
) (string, *types.APIKey, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
|
@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
|||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if err := hsdb.db.Save(&key).Error; err != nil {
|
||||
if err := hsdb.DB.Save(&key).Error; err != nil {
|
||||
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
|||
|
||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
keys := []types.APIKey{}
|
||||
if err := hsdb.db.Find(&keys).Error; err != nil {
|
||||
if err := hsdb.DB.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
|||
|
||||
// GetAPIKey returns a ApiKey for a given key.
|
||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||
if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
|||
|
||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
|||
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
||||
if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
|||
|
||||
// ExpireAPIKey marks a ApiKey as expired.
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
prefix, hash, found := strings.Cut(keyStr, ".")
|
||||
if !found {
|
||||
return false, ErrAPIKeyFailedToParse
|
||||
|
|
|
@ -6,24 +6,20 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/go-gormigrate/gormigrate/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
Postgres = "postgres"
|
||||
Sqlite = "sqlite3"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
var errDatabaseNotSupported = errors.New("database type not supported")
|
||||
|
@ -36,12 +32,7 @@ type KV struct {
|
|||
}
|
||||
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifier *notifier.Notifier
|
||||
|
||||
mu sync.RWMutex
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
DB *gorm.DB
|
||||
|
||||
ipPrefixes []netip.Prefix
|
||||
baseDomain string
|
||||
|
@ -50,18 +41,20 @@ type HSDatabase struct {
|
|||
// TODO(kradalby): assemble this struct from toptions or something typed
|
||||
// rather than arguments.
|
||||
func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
debug bool,
|
||||
cfg types.DatabaseConfig,
|
||||
notifier *notifier.Notifier,
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(dbType, connectionAddr, debug)
|
||||
dbConn, err := openDB(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{
|
||||
migrations := gormigrate.New(
|
||||
dbConn,
|
||||
gormigrate.DefaultOptions,
|
||||
[]*gormigrate.Migration{
|
||||
// New migrations should be added as transactions at the end of this list.
|
||||
// The initial commit here is quite messy, completely out of order and
|
||||
// has no versioning and is the tech debt of not having versioned migrations
|
||||
|
@ -70,7 +63,7 @@ func NewHeadscaleDatabase(
|
|||
{
|
||||
ID: "202312101416",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
if dbType == Postgres {
|
||||
if cfg.Type == types.DatabasePostgres {
|
||||
tx.Exec(`create extension if not exists "uuid-ossp";`)
|
||||
}
|
||||
|
||||
|
@ -78,23 +71,28 @@ func NewHeadscaleDatabase(
|
|||
|
||||
// the big rename from Machine to Node
|
||||
_ = tx.Migrator().RenameTable("machines", "nodes")
|
||||
_ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id")
|
||||
_ = tx.Migrator().
|
||||
RenameColumn(&types.Route{}, "machine_id", "node_id")
|
||||
|
||||
err = tx.AutoMigrate(types.User{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id")
|
||||
_ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
|
||||
_ = tx.Migrator().
|
||||
RenameColumn(&types.Node{}, "namespace_id", "user_id")
|
||||
_ = tx.Migrator().
|
||||
RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
|
||||
|
||||
_ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
|
||||
_ = tx.Migrator().
|
||||
RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
|
||||
_ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
|
||||
|
||||
// GivenName is used as the primary source of DNS names, make sure
|
||||
// the field is populated and normalized if it was not when the
|
||||
// node was registered.
|
||||
_ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
|
||||
_ = tx.Migrator().
|
||||
RenameColumn(&types.Node{}, "nickname", "given_name")
|
||||
|
||||
// If the Node table has a column for registered,
|
||||
// find all occourences of "false" and drop them. Then
|
||||
|
@ -147,7 +145,9 @@ func NewHeadscaleDatabase(
|
|||
DiscoKey string
|
||||
}
|
||||
var results []result
|
||||
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
|
||||
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").
|
||||
Find(&results).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -180,7 +180,8 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
|
||||
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
|
||||
log.Info().
|
||||
Msgf("Database has legacy enabled_routes column in node, migrating...")
|
||||
|
||||
type NodeAux struct {
|
||||
ID uint64
|
||||
|
@ -188,7 +189,10 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
nodesAux := []NodeAux{}
|
||||
err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error
|
||||
err := tx.Table("nodes").
|
||||
Select("id, enabled_routes").
|
||||
Scan(&nodesAux).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
@ -234,7 +238,9 @@ func NewHeadscaleDatabase(
|
|||
|
||||
err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("Error dropping enabled_routes column")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -310,15 +316,15 @@ func NewHeadscaleDatabase(
|
|||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
if err = migrations.Migrate(); err != nil {
|
||||
log.Fatal().Err(err).Msgf("Migration failed: %v", err)
|
||||
}
|
||||
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifier: notifier,
|
||||
DB: dbConn,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
|
@ -327,20 +333,19 @@ func NewHeadscaleDatabase(
|
|||
return &db, err
|
||||
}
|
||||
|
||||
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
||||
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
|
||||
|
||||
func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
// TODO(kradalby): Integrate this with zerolog
|
||||
var dbLogger logger.Interface
|
||||
if debug {
|
||||
if cfg.Debug {
|
||||
dbLogger = logger.Default
|
||||
} else {
|
||||
dbLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case Sqlite:
|
||||
switch cfg.Type {
|
||||
case types.DatabaseSqlite:
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
|
||||
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
|
||||
&gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: dbLogger,
|
||||
|
@ -359,16 +364,51 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
|||
|
||||
return db, err
|
||||
|
||||
case Postgres:
|
||||
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
|
||||
case types.DatabasePostgres:
|
||||
dbString := fmt.Sprintf(
|
||||
"host=%s dbname=%s user=%s",
|
||||
cfg.Postgres.Host,
|
||||
cfg.Postgres.Name,
|
||||
cfg.Postgres.User,
|
||||
)
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
}
|
||||
} else {
|
||||
dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
|
||||
}
|
||||
|
||||
if cfg.Postgres.Port != 0 {
|
||||
dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
|
||||
}
|
||||
|
||||
if cfg.Postgres.Pass != "" {
|
||||
dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: dbLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections)
|
||||
sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections)
|
||||
sqlDB.SetConnMaxIdleTime(
|
||||
time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second,
|
||||
)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"database of type %s is not supported: %w",
|
||||
dbType,
|
||||
cfg.Type,
|
||||
errDatabaseNotSupported,
|
||||
)
|
||||
}
|
||||
|
@ -376,7 +416,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
|||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
sqlDB, err := hsdb.db.DB()
|
||||
sqlDB, err := hsdb.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -385,10 +425,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) Close() error {
|
||||
db, err := hsdb.db.DB()
|
||||
db, err := hsdb.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
||||
rx := hsdb.DB.Begin()
|
||||
defer rx.Rollback()
|
||||
return fn(rx)
|
||||
}
|
||||
|
||||
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
rx := db.Begin()
|
||||
defer rx.Rollback()
|
||||
ret, err := fn(rx)
|
||||
if err != nil {
|
||||
var no T
|
||||
return no, err
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
tx := hsdb.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
||||
tx := db.Begin()
|
||||
defer tx.Rollback()
|
||||
ret, err := fn(tx)
|
||||
if err != nil {
|
||||
var no T
|
||||
return no, err
|
||||
}
|
||||
return ret, tx.Commit().Error
|
||||
}
|
||||
|
|
|
@ -34,22 +34,21 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
|
||||
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listPeers(node)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListPeers(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
|
||||
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
|
||||
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
|
|||
return nodes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listNodes()
|
||||
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
|
||||
nodes := []types.Node{}
|
||||
if err := hsdb.db.
|
||||
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
|
|||
return nodes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listNodesByGivenName(givenName)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
|
||||
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err
|
|||
return nodes, nil
|
||||
}
|
||||
|
||||
// GetNode finds a Node by name and user and returns the Node struct.
|
||||
func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return getNode(rx, user, name)
|
||||
})
|
||||
}
|
||||
|
||||
nodes, err := hsdb.ListNodesByUser(user)
|
||||
// getNode finds a Node by name and user and returns the Node struct.
|
||||
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
|
||||
nodes, err := ListNodesByUser(tx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
|
|||
return nil, ErrNodeNotFound
|
||||
}
|
||||
|
||||
// GetNodeByGivenName finds a Node by given name and user and returns the Node struct.
|
||||
func (hsdb *HSDatabase) GetNodeByGivenName(
|
||||
user string,
|
||||
givenName string,
|
||||
) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
node := types.Node{}
|
||||
if err := hsdb.db.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
Where("given_name = ?", givenName).First(&node).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, ErrNodeNotFound
|
||||
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByID(rx, id)
|
||||
})
|
||||
}
|
||||
|
||||
// GetNodeByID finds a Node by ID and returns the Node struct.
|
||||
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
|
||||
mach := types.Node{}
|
||||
if result := hsdb.db.
|
||||
if result := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
|
|||
return &mach, nil
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
|
||||
func (hsdb *HSDatabase) GetNodeByMachineKey(
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getNodeByMachineKey(machineKey)
|
||||
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByMachineKey(rx, machineKey)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNodeByMachineKey(
|
||||
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
|
||||
func GetNodeByMachineKey(
|
||||
tx *gorm.DB,
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, error) {
|
||||
mach := types.Node{}
|
||||
if result := hsdb.db.
|
||||
if result := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey(
|
|||
return &mach, nil
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey finds a Node by its current NodeKey.
|
||||
func (hsdb *HSDatabase) GetNodeByNodeKey(
|
||||
func (hsdb *HSDatabase) GetNodeByAnyKey(
|
||||
machineKey key.MachinePublic,
|
||||
nodeKey key.NodePublic,
|
||||
oldNodeKey key.NodePublic,
|
||||
) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
node := types.Node{}
|
||||
if result := hsdb.db.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
First(&node, "node_key = ?",
|
||||
nodeKey.String()); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &node, nil
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
|
||||
})
|
||||
}
|
||||
|
||||
// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
|
||||
func (hsdb *HSDatabase) GetNodeByAnyKey(
|
||||
// TODO(kradalby): see if we can remove this.
|
||||
func GetNodeByAnyKey(
|
||||
tx *gorm.DB,
|
||||
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
||||
) (*types.Node, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
node := types.Node{}
|
||||
if result := hsdb.db.
|
||||
if result := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
|
@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
|
|||
return &node, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
if result := hsdb.db.Find(node).First(&node); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
func (hsdb *HSDatabase) SetTags(
|
||||
nodeID uint64,
|
||||
tags []string,
|
||||
) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return SetTags(tx, nodeID, tags)
|
||||
})
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
func (hsdb *HSDatabase) SetTags(
|
||||
node *types.Node,
|
||||
func SetTags(
|
||||
tx *gorm.DB,
|
||||
nodeID uint64,
|
||||
tags []string,
|
||||
) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
newTags := []string{}
|
||||
newTags := types.StringList{}
|
||||
for _, tag := range tags {
|
||||
if !util.StringOrPrefixListContains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
ForcedTags: newTags,
|
||||
}).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for node in the database: %w", err)
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from db.SetTags",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenameNode takes a Node struct and a new GivenName for the nodes
|
||||
// and renames it.
|
||||
func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func RenameNode(tx *gorm.DB,
|
||||
nodeID uint64, newName string,
|
||||
) error {
|
||||
err := util.CheckForFQDNRules(
|
||||
newName,
|
||||
)
|
||||
|
@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
|
|||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RenameNode").
|
||||
Str("node", node.Hostname).
|
||||
Uint64("nodeID", nodeID).
|
||||
Str("newName", newName).
|
||||
Err(err).
|
||||
Msg("failed to rename node")
|
||||
|
||||
return err
|
||||
}
|
||||
node.GivenName = newName
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
GivenName: newName,
|
||||
}).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename node in the database: %w", err)
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from db.RenameNode",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
}
|
||||
|
||||
// NodeSetExpiry takes a Node struct and a new expiry time.
|
||||
func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.nodeSetExpiry(node, expiry)
|
||||
func NodeSetExpiry(tx *gorm.DB,
|
||||
nodeID uint64, expiry time.Time,
|
||||
) error {
|
||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error {
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
Expiry: &expiry,
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to refresh node (update expiration) in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
node.Expiry = &expiry
|
||||
|
||||
stateSelfUpdate := types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if stateSelfUpdate.Valid() {
|
||||
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: &expiry,
|
||||
},
|
||||
},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DeleteNode(tx, node, isConnected)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteNode deletes a Node from the database.
|
||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.deleteNode(node)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
|
||||
err := hsdb.deleteNodeRoutes(node)
|
||||
// Caller is responsible for notifying all of change.
|
||||
func DeleteNode(tx *gorm.DB,
|
||||
node *types.Node,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) error {
|
||||
err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Unscoped causes the node to be fully removed from the database.
|
||||
if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil {
|
||||
if err := tx.Unscoped().Delete(&node).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastSeen sets a node's last seen field indicating that we
|
||||
// have recently communicating with this node.
|
||||
// This is mostly used to indicate if a node is online and is not
|
||||
// extremely important to make sure is fully correct and to avoid
|
||||
// holding up the hot path, does not contain any locks and isnt
|
||||
// concurrency safe. But that should be ok.
|
||||
func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
|
||||
return hsdb.db.Model(node).Updates(types.Node{
|
||||
LastSeen: node.LastSeen,
|
||||
}).Error
|
||||
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
|
||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
||||
func RegisterNodeFromAuthCallback(
|
||||
tx *gorm.DB,
|
||||
cache *cache.Cache,
|
||||
mkey key.MachinePublic,
|
||||
userName string,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipPrefixes []netip.Prefix,
|
||||
) (*types.Node, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
log.Debug().
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
Str("userName", userName).
|
||||
|
@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
|||
|
||||
if nodeInterface, ok := cache.Get(mkey.String()); ok {
|
||||
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
||||
user, err := hsdb.getUser(userName)
|
||||
user, err := GetUser(tx, userName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
|
@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
|||
}
|
||||
|
||||
registrationNode.UserID = user.ID
|
||||
registrationNode.User = *user
|
||||
registrationNode.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
registrationNode.Expiry = nodeExpiry
|
||||
}
|
||||
|
||||
node, err := hsdb.registerNode(
|
||||
node, err := RegisterNode(
|
||||
tx,
|
||||
registrationNode,
|
||||
ipPrefixes,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
|
@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
|||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.registerNode(node)
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, hsdb.ipPrefixes)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
|
@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
|||
// so we store the node.Expire and node.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if len(node.IPAddresses) > 0 {
|
||||
if err := hsdb.db.Save(&node).Error; err != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
|||
return &node, nil
|
||||
}
|
||||
|
||||
hsdb.ipAllocationMutex.Lock()
|
||||
defer hsdb.ipAllocationMutex.Unlock()
|
||||
|
||||
ips, err := hsdb.getAvailableIPs()
|
||||
ips, err := getAvailableIPs(tx, ipPrefixes)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
|||
|
||||
node.IPAddresses = ips
|
||||
|
||||
if err := hsdb.db.Save(&node).Error; err != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
|
|||
}
|
||||
|
||||
// NodeSetNodeKey sets the node key of a node and saves it to the database.
|
||||
func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error {
|
||||
return tx.Model(node).Updates(types.Node{
|
||||
NodeKey: nodeKey,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NodeSetMachineKey sets the node key of a node and saves it to the database.
|
||||
func (hsdb *HSDatabase) NodeSetMachineKey(
|
||||
node *types.Node,
|
||||
machineKey key.MachinePublic,
|
||||
) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||
MachineKey: machineKey,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return NodeSetMachineKey(tx, node, machineKey)
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
// NodeSetMachineKey sets the node key of a node and saves it to the database.
|
||||
func NodeSetMachineKey(
|
||||
tx *gorm.DB,
|
||||
node *types.Node,
|
||||
machineKey key.MachinePublic,
|
||||
) error {
|
||||
return tx.Model(node).Updates(types.Node{
|
||||
MachineKey: machineKey,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// NodeSave saves a node object to the database, prefer to use a specific save method rather
|
||||
// than this. It is intended to be used when we are changing or.
|
||||
func (hsdb *HSDatabase) NodeSave(node *types.Node) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if err := hsdb.db.Save(node).Error; err != nil {
|
||||
return err
|
||||
// TODO(kradalby): Remove this func, just use Save.
|
||||
func NodeSave(tx *gorm.DB, node *types.Node) error {
|
||||
return tx.Save(node).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
|
||||
return GetAdvertisedRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAdvertisedRoutes returns the routes that are be advertised by the given node.
|
||||
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getAdvertisedRoutes(node)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e
|
|||
return prefixes, nil
|
||||
}
|
||||
|
||||
// GetEnabledRoutes returns the routes that are enabled for the node.
|
||||
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getEnabledRoutes(node)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
|
||||
return GetEnabledRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
|
||||
// GetEnabledRoutes returns the routes that are enabled for the node.
|
||||
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
|
||||
Find(&routes).Error
|
||||
|
@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro
|
|||
return prefixes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := hsdb.getEnabledRoutes(node)
|
||||
enabledRoutes, err := GetEnabledRoutes(tx, node)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Could not get enabled routes")
|
||||
|
||||
|
@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
|
|||
return false
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) enableRoutes(
|
||||
node *types.Node,
|
||||
routeStrs ...string,
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return enableRoutes(tx, node, routeStrs...)
|
||||
})
|
||||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
|
||||
func enableRoutes(tx *gorm.DB,
|
||||
node *types.Node, routeStrs ...string,
|
||||
) (*types.StateUpdate, error) {
|
||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
advertisedRoutes, err := hsdb.getAdvertisedRoutes(node)
|
||||
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
|
||||
return fmt.Errorf(
|
||||
return nil, fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
node.Hostname,
|
||||
newRoute, ErrNodeRouteIsNotAvailable,
|
||||
|
@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
|
|||
// Separate loop so we don't leave things in a half-updated state
|
||||
for _, prefix := range newRoutes {
|
||||
route := types.Route{}
|
||||
err := hsdb.db.Preload("Node").
|
||||
err := tx.Preload("Node").
|
||||
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
|
||||
First(&route).Error
|
||||
if err == nil {
|
||||
|
@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
|
|||
// Mark already as primary if there is only this node offering this subnet
|
||||
// (and is not an exit route)
|
||||
if !route.IsExitRoute() {
|
||||
route.IsPrimary = hsdb.isUniquePrefix(route)
|
||||
route.IsPrimary = isUniquePrefix(tx, route)
|
||||
}
|
||||
|
||||
err = hsdb.db.Save(&route).Error
|
||||
err = tx.Save(&route).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to enable route: %w", err)
|
||||
return nil, fmt.Errorf("failed to enable route: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to find route: %w", err)
|
||||
return nil, fmt.Errorf("failed to find route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the node has the latest routes when notifying the other
|
||||
// nodes
|
||||
nRoutes, err := hsdb.getNodeRoutes(node)
|
||||
nRoutes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read back routes: %w", err)
|
||||
return nil, fmt.Errorf("failed to read back routes: %w", err)
|
||||
}
|
||||
|
||||
node.Routes = nRoutes
|
||||
|
@ -729,17 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
|
|||
Strs("routes", routeStrs).
|
||||
Msg("enabling routes")
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from db.enableRoutes",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyWithIgnore(
|
||||
stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
Message: "created in db.enableRoutes",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
|
@ -772,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName(
|
|||
mkey key.MachinePublic,
|
||||
suppliedName string,
|
||||
) (string, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
|
||||
return GenerateGivenName(rx, mkey, suppliedName)
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateGivenName(
|
||||
tx *gorm.DB,
|
||||
mkey key.MachinePublic,
|
||||
suppliedName string,
|
||||
) (string, error) {
|
||||
givenName, err := generateGivenName(suppliedName, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||
nodes, err := hsdb.listNodesByGivenName(givenName)
|
||||
nodes, err := listNodesByGivenName(tx, givenName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -805,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName(
|
|||
return givenName, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
users, err := hsdb.listUsers()
|
||||
func ExpireEphemeralNodes(tx *gorm.DB,
|
||||
inactivityThreshhold time.Duration,
|
||||
) (types.StateUpdate, bool) {
|
||||
users, err := ListUsers(tx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
return types.StateUpdate{}, false
|
||||
}
|
||||
|
||||
expired := make([]tailcfg.NodeID, 0)
|
||||
for _, user := range users {
|
||||
nodes, err := hsdb.listNodesByUser(user.Name)
|
||||
nodes, err := ListNodesByUser(tx, user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user", user.Name).
|
||||
Msg("Error listing nodes in user")
|
||||
|
||||
return
|
||||
return types.StateUpdate{}, false
|
||||
}
|
||||
|
||||
expired := make([]tailcfg.NodeID, 0)
|
||||
for idx, node := range nodes {
|
||||
if node.IsEphemeral() && node.LastSeen != nil &&
|
||||
time.Now().
|
||||
|
@ -838,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
|
|||
Str("node", node.Hostname).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = hsdb.deleteNode(nodes[idx])
|
||||
// empty isConnected map as ephemeral nodes are not routes
|
||||
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -848,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): needs to be moved out of transaction
|
||||
}
|
||||
if len(expired) > 0 {
|
||||
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||
return types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: expired,
|
||||
})
|
||||
}
|
||||
}
|
||||
}, true
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
return types.StateUpdate{}, false
|
||||
}
|
||||
|
||||
func ExpireExpiredNodes(tx *gorm.DB,
|
||||
lastCheck time.Time,
|
||||
) (time.Time, types.StateUpdate, bool) {
|
||||
// use the time of the start of the function to ensure we
|
||||
// dont miss some nodes by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
expiredNodes := make([]*types.Node, 0)
|
||||
expired := make([]*tailcfg.PeerChange, 0)
|
||||
|
||||
nodes, err := hsdb.listNodes()
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("Error listing nodes to find expired nodes")
|
||||
|
||||
return time.Unix(0, 0)
|
||||
return time.Unix(0, 0), types.StateUpdate{}, false
|
||||
}
|
||||
for index, node := range nodes {
|
||||
if node.IsExpired() &&
|
||||
|
@ -882,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
|
|||
// It will notify about all nodes that has been expired.
|
||||
// It should only notify about expired nodes since _last check_.
|
||||
node.Expiry.After(lastCheck) {
|
||||
expiredNodes = append(expiredNodes, &nodes[index])
|
||||
expired = append(expired, &tailcfg.PeerChange{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: node.Expiry,
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
// Do not use setNodeExpiry as that has a notifier hook, which
|
||||
// can cause a deadlock, we are updating all changed nodes later
|
||||
// and there is no point in notifiying twice.
|
||||
if err := hsdb.db.Model(nodes[index]).Updates(types.Node{
|
||||
Expiry: &started,
|
||||
if err := tx.Model(&nodes[index]).Updates(types.Node{
|
||||
Expiry: &now,
|
||||
}).Error; err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -904,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
|
|||
}
|
||||
}
|
||||
|
||||
expired := make([]*tailcfg.PeerChange, len(expiredNodes))
|
||||
for idx, node := range expiredNodes {
|
||||
expired[idx] = &tailcfg.PeerChange{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: &started,
|
||||
}
|
||||
}
|
||||
|
||||
// Inform the peers of a node with a lightweight update.
|
||||
stateUpdate := types.StateUpdate{
|
||||
if len(expired) > 0 {
|
||||
return started, types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: expired,
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}, true
|
||||
}
|
||||
|
||||
// Inform the node itself that it has expired.
|
||||
for _, node := range expiredNodes {
|
||||
stateSelfUpdate := types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if stateSelfUpdate.Valid() {
|
||||
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
|
||||
}
|
||||
}
|
||||
|
||||
return started
|
||||
return started, types.StateUpdate{}, false
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(node)
|
||||
db.DB.Save(node)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
|
@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
|
||||
_, err = db.GetNodeByNodeKey(nodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
err = db.DeleteNode(&node)
|
||||
err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode(user.Name, "testnode3")
|
||||
_, err = db.getNode(user.Name, "testnode3")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
|
@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
}
|
||||
|
||||
node0ByID, err := db.GetNodeByID(0)
|
||||
|
@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(stor[index%2].key.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
}
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
|
@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.db.Save(node)
|
||||
db.DB.Save(node)
|
||||
|
||||
nodeFromDB, err := db.GetNode("test", "testnode")
|
||||
nodeFromDB, err := db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(nodeFromDB, check.NotNil)
|
||||
|
||||
c.Assert(nodeFromDB.IsExpired(), check.Equals, false)
|
||||
|
||||
now := time.Now()
|
||||
err = db.NodeSetExpiry(nodeFromDB, now)
|
||||
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
nodeFromDB, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||
|
@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("user-1", "testnode")
|
||||
_, err = db.getNode("user-1", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(node)
|
||||
db.DB.Save(node)
|
||||
|
||||
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
|
||||
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
|
||||
|
@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "testnode")
|
||||
_, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(node)
|
||||
db.DB.Save(node)
|
||||
|
||||
// assign simple tags
|
||||
sTags := []string{"tag:test", "tag:foo"}
|
||||
err = db.SetTags(node, sTags)
|
||||
err = db.SetTags(node.ID, sTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
node, err = db.GetNode("test", "testnode")
|
||||
node, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
||||
|
||||
// assign duplicat tags, expect no errors but no doubles in DB
|
||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||
err = db.SetTags(node, eTags)
|
||||
err = db.SetTags(node.ID, eTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
node, err = db.GetNode("test", "testnode")
|
||||
node, err = db.getNode("test", "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
node.ForcedTags,
|
||||
|
@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
}
|
||||
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
sendUpdate, err := db.SaveNodeRoutes(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
node0ByID, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.EnableAutoApprovedRoutes(pol, node0ByID)
|
||||
// TODO(kradalby): Check state update
|
||||
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
|
||||
|
|
|
@ -20,7 +20,6 @@ var (
|
|||
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||
)
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
userName string,
|
||||
reusable bool,
|
||||
|
@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
// TODO(kradalby): figure out this lock
|
||||
// hsdb.mu.Lock()
|
||||
// defer hsdb.mu.Unlock()
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
|
||||
})
|
||||
}
|
||||
|
||||
user, err := hsdb.GetUser(userName)
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func CreatePreAuthKey(
|
||||
tx *gorm.DB,
|
||||
userName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
kstr, err := hsdb.generateKey()
|
||||
kstr, err := generateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -63,9 +72,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err = hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
if err := db.Save(&key).Error; err != nil {
|
||||
return fmt.Errorf("failed to create key in the database: %w", err)
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create key in the database: %w", err)
|
||||
}
|
||||
|
||||
if len(aclTags) > 0 {
|
||||
|
@ -73,8 +81,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
|
||||
for _, tag := range aclTags {
|
||||
if !seenTags[tag] {
|
||||
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
|
@ -84,9 +92,6 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
return &key, nil
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listPreAuthKeys(userName)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeys(rx, userName)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := hsdb.getUser(userName)
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []types.PreAuthKey{}
|
||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
|
|||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
pak, err := hsdb.ValidatePreAuthKey(key)
|
||||
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
|
||||
pak, err := ValidatePreAuthKey(tx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
|
|||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.destroyPreAuthKey(pak)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
|
||||
return tx.Transaction(func(db *gorm.DB) error {
|
||||
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
|
|||
})
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return ExpirePreAuthKey(tx, k)
|
||||
})
|
||||
}
|
||||
|
||||
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
|||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
if err := hsdb.db.Save(k).Error; err != nil {
|
||||
if err := tx.Save(k).Error; err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return ValidatePreAuthKey(rx, k)
|
||||
})
|
||||
}
|
||||
|
||||
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
|
|||
}
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
|
||||
Find(&nodes).Error; err != nil {
|
||||
|
@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
|
|||
return &pak, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||
func generateKey() (string, error) {
|
||||
size := 24
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
|
@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
|||
user, err := db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
now := time.Now().Add(-5 * time.Second)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
|
@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
|||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
// Ephemeral keys are by definition reusable
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test7", "testest")
|
||||
_, err = db.getNode("test7", "testest")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
db.ExpireEphemeralNodes(time.Second * 20)
|
||||
db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
ExpireEphemeralNodes(tx, time.Second*20)
|
||||
return nil
|
||||
})
|
||||
|
||||
// The machine record should have been deleted
|
||||
_, err = db.GetNode("test7", "testest")
|
||||
_, err = db.getNode("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
|
@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak.Used = true
|
||||
db.db.Save(&pak)
|
||||
db.DB.Save(&pak)
|
||||
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
|
|
|
@ -7,23 +7,15 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoutes()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
|
||||
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Find(&routes).Error
|
||||
|
@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
|
||||
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("advertised = ? AND enabled = ?", true, true).
|
||||
|
@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
|
||||
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("prefix = ?", types.IPPrefix(pref)).
|
||||
|
@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getNodeAdvertisedRoutes(node)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
|
||||
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ? AND advertised = true", node.ID).
|
||||
|
@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
|
|||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getNodeRoutes(node)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodeRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
|
||||
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
Where("node_id = ?", node.ID).
|
||||
|
@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
|
||||
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Preload("Node.User").
|
||||
First(&route, id).Error
|
||||
|
@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
|
|||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.enableRoute(id)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) enableRoute(id uint64) error {
|
||||
route, err := hsdb.getRoute(id)
|
||||
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if route.IsExitRoute() {
|
||||
return hsdb.enableRoutes(
|
||||
return enableRoutes(
|
||||
tx,
|
||||
&route.Node,
|
||||
types.ExitRouteV4.String(),
|
||||
types.ExitRouteV6.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String())
|
||||
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
func DisableRoute(tx *gorm.DB,
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
|
@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
|||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
var update *types.StateUpdate
|
||||
if !route.IsExitRoute() {
|
||||
err = hsdb.failoverRouteWithNotify(route)
|
||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
route.Enabled = false
|
||||
route.IsPrimary = false
|
||||
err = hsdb.db.Save(route).Error
|
||||
err = tx.Save(route).Error
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if routes[i].IsExitRoute() {
|
||||
routes[i].Enabled = false
|
||||
routes[i].IsPrimary = false
|
||||
err = hsdb.db.Save(&routes[i]).Error
|
||||
err = tx.Save(&routes[i]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if routes == nil {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node.Routes = routes
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if update == nil {
|
||||
update = &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{&node},
|
||||
ChangeNodes: types.Nodes{
|
||||
&node,
|
||||
},
|
||||
Message: "called from db.DisableRoute",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
func (hsdb *HSDatabase) DeleteRoute(
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return DeleteRoute(tx, id, isConnected)
|
||||
})
|
||||
}
|
||||
|
||||
route, err := hsdb.getRoute(id)
|
||||
func DeleteRoute(
|
||||
tx *gorm.DB,
|
||||
id uint64,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
) (*types.StateUpdate, error) {
|
||||
route, err := GetRoute(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var routes types.Routes
|
||||
|
@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
var update *types.StateUpdate
|
||||
if !route.IsExitRoute() {
|
||||
err := hsdb.failoverRouteWithNotify(route)
|
||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||
return err
|
||||
if err := tx.Unscoped().Delete(&route).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
routes, err := hsdb.getNodeRoutes(&node)
|
||||
routes, err := GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routesToDelete := types.Routes{}
|
||||
|
@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return err
|
||||
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If update is empty, it means that one was not created
|
||||
// by failover (as a failover was not necessary), create
|
||||
// one and return to the caller.
|
||||
if routes == nil {
|
||||
routes, err = hsdb.getNodeRoutes(&node)
|
||||
routes, err = GetNodeRoutes(tx, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node.Routes = routes
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
if update == nil {
|
||||
update = &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{&node},
|
||||
ChangeNodes: types.Nodes{
|
||||
&node,
|
||||
},
|
||||
Message: "called from db.DeleteRoute",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
|
||||
routes, err := hsdb.getNodeRoutes(node)
|
||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
|
||||
routes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(kradalby): This is a bit too aggressive, we could probably
|
||||
// figure out which routes needs to be failed over rather than all.
|
||||
hsdb.failoverRouteWithNotify(&routes[i])
|
||||
failoverRouteReturnUpdate(tx, isConnected, &routes[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another node providing the same route already.
|
||||
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
||||
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
|
||||
var count int64
|
||||
hsdb.db.
|
||||
Model(&types.Route{}).
|
||||
tx.Model(&types.Route{}).
|
||||
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.NodeID,
|
||||
|
@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
|||
return count == 0
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
|
||||
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||
First(&route).Error
|
||||
|
@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
|
|||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return GetNodePrimaryRoutes(rx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
err := tx.
|
||||
Preload("Node").
|
||||
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
|
||||
Find(&routes).Error
|
||||
|
@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
// SaveNodeRoutes takes a node and updates the database with
|
||||
// the new routes.
|
||||
// It returns a bool wheter an update should be sent as the
|
||||
// saved route impacts nodes.
|
||||
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
return hsdb.saveNodeRoutes(node)
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
|
||||
return SaveNodeRoutes(tx, node)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
||||
// SaveNodeRoutes takes a node and updates the database with
|
||||
// the new routes.
|
||||
// It returns a bool whether an update should be sent as the
|
||||
// saved route impacts nodes.
|
||||
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||
sendUpdate := false
|
||||
|
||||
currentRoutes := types.Routes{}
|
||||
err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error
|
||||
err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
||||
if !route.Advertised {
|
||||
currentRoutes[pos].Advertised = true
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
} else if route.Advertised {
|
||||
currentRoutes[pos].Advertised = false
|
||||
currentRoutes[pos].Enabled = false
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
err := tx.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
err := hsdb.db.Create(&route).Error
|
||||
err := tx.Create(&route).Error
|
||||
if err != nil {
|
||||
return sendUpdate, err
|
||||
}
|
||||
|
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
|||
|
||||
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
|
||||
// currently have a functioning host that exposes the network.
|
||||
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
|
||||
nodeRoutes, err := hsdb.getNodeRoutes(node)
|
||||
func EnsureFailoverRouteIsAvailable(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
node *types.Node,
|
||||
) (*types.StateUpdate, error) {
|
||||
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var changedNodes types.Nodes
|
||||
for _, nodeRoute := range nodeRoutes {
|
||||
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.IsPrimary {
|
||||
// if we have a primary route, and the node is connected
|
||||
// nothing needs to be done.
|
||||
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
|
||||
if isConnected[route.Node.MachineKey] {
|
||||
continue
|
||||
}
|
||||
|
||||
// if not, we need to failover the route
|
||||
err := hsdb.failoverRouteWithNotify(&route)
|
||||
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
changedNodes = append(changedNodes, update.ChangeNodes...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
|
||||
routes, err := hsdb.getNodeRoutes(node)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var changedKeys []key.MachinePublic
|
||||
|
||||
for _, route := range routes {
|
||||
changed, err := hsdb.failoverRoute(&route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
changedKeys = append(changedKeys, changed...)
|
||||
}
|
||||
|
||||
changedKeys = lo.Uniq(changedKeys)
|
||||
|
||||
var nodes types.Nodes
|
||||
|
||||
for _, key := range changedKeys {
|
||||
node, err := hsdb.GetNodeByMachineKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
if nodes != nil {
|
||||
stateUpdate := types.StateUpdate{
|
||||
if len(changedNodes) != 0 {
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: nodes,
|
||||
Message: "called from db.FailoverNodeRoutesWithNotify",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
ChangeNodes: changedNodes,
|
||||
Message: "called from db.EnsureFailoverRouteIsAvailable",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
|
||||
changedKeys, err := hsdb.failoverRoute(r)
|
||||
func failoverRouteReturnUpdate(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
r *types.Route,
|
||||
) (*types.StateUpdate, error) {
|
||||
changedKeys, err := failoverRoute(tx, isConnected, r)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Interface("isConnected", isConnected).
|
||||
Interface("changedKeys", changedKeys).
|
||||
Msg("building route failover")
|
||||
|
||||
if len(changedKeys) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var nodes types.Nodes
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("loading machines with new primary routes from db")
|
||||
|
||||
for _, key := range changedKeys {
|
||||
node, err := hsdb.getNodeByMachineKey(key)
|
||||
node, err := GetNodeByMachineKey(tx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("notifying peers about primary route change")
|
||||
|
||||
if nodes != nil {
|
||||
stateUpdate := types.StateUpdate{
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: nodes,
|
||||
Message: "called from db.failoverRouteWithNotify",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
hsdb.notifier.NotifyAll(stateUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("hostname", r.Node.Hostname).
|
||||
Msg("notified peers about primary route change")
|
||||
|
||||
return nil
|
||||
Message: "called from db.failoverRouteReturnUpdate",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// failoverRoute takes a route that is no longer available,
|
||||
|
@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
|
|||
//
|
||||
// and tries to find a new route to take over its place.
|
||||
// If the given route was not primary, it returns early.
|
||||
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
|
||||
func failoverRoute(
|
||||
tx *gorm.DB,
|
||||
isConnected map[key.MachinePublic]bool,
|
||||
r *types.Route,
|
||||
) ([]key.MachinePublic, error) {
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This route is not a primary route, and it isnt
|
||||
// This route is not a primary route, and it is not
|
||||
// being served to nodes.
|
||||
if !r.IsPrimary {
|
||||
return nil, nil
|
||||
|
@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -585,14 +543,18 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
continue
|
||||
}
|
||||
|
||||
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
|
||||
if !route.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if isConnected[route.Node.MachineKey] {
|
||||
newPrimary = &routes[idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If a new route was not found/available,
|
||||
// return with an error.
|
||||
// return without an error.
|
||||
// We do not want to update the database as
|
||||
// the one currently marked as primary is the
|
||||
// best we got.
|
||||
|
@ -606,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
|
||||
// Remove primary from the old route
|
||||
r.IsPrimary = false
|
||||
err = hsdb.db.Save(&r).Error
|
||||
err = tx.Save(&r).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error disabling new primary route")
|
||||
|
||||
|
@ -619,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
|
||||
// Set primary for the new primary
|
||||
newPrimary.IsPrimary = true
|
||||
err = hsdb.db.Save(&newPrimary).Error
|
||||
err = tx.Save(&newPrimary).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error enabling new primary route")
|
||||
|
||||
|
@ -634,19 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
|
|||
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
if len(node.IPAddresses) == 0 {
|
||||
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
) (*types.StateUpdate, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
|
||||
})
|
||||
}
|
||||
|
||||
routes, err := hsdb.getNodeAdvertisedRoutes(node)
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||
func EnableAutoApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
) (*types.StateUpdate, error) {
|
||||
if len(node.IPAddresses) == 0 {
|
||||
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
routes, err := GetNodeAdvertisedRoutes(tx, node)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -654,9 +623,11 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Str("node", node.Hostname).
|
||||
Msg("Could not get advertised routes for node")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
|
||||
|
||||
approvedRoutes := types.Routes{}
|
||||
|
||||
for _, advertisedRoute := range routes {
|
||||
|
@ -673,9 +644,16 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Uint64("nodeId", node.ID).
|
||||
Msg("Failed to resolve autoApprovers for advertised route")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Str("user", node.User.Name).
|
||||
Strs("routeApprovers", routeApprovers).
|
||||
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
||||
Msg("looking up route for autoapproving")
|
||||
|
||||
for _, approvedAlias := range routeApprovers {
|
||||
if approvedAlias == node.User.Name {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
|
@ -687,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
Str("alias", approvedAlias).
|
||||
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
|
@ -698,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
|||
}
|
||||
}
|
||||
|
||||
update := &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{},
|
||||
Message: "created in db.EnableAutoApprovedRoutes",
|
||||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
err := hsdb.enableRoute(uint64(approvedRoute.ID))
|
||||
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
Uint64("nodeId", node.ID).
|
||||
Msg("Failed to enable approved route")
|
||||
|
||||
return err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
|
||||
}
|
||||
|
||||
return update, nil
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "test_get_route_node")
|
||||
_, err = db.getNode("test", "test_get_route_node")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
|
@ -42,7 +42,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
su, err := db.SaveNodeRoutes(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -52,10 +52,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||
|
||||
err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
// TODO(kradalby): check state update
|
||||
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
|
@ -66,7 +67,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "test_enable_route_node")
|
||||
_, err = db.getNode("test", "test_enable_route_node")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
|
@ -91,7 +92,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
sendUpdate, err := db.SaveNodeRoutes(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -106,10 +107,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||
|
||||
err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(&node)
|
||||
|
@ -117,14 +118,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||
|
||||
err = db.enableRoutes(&node, "150.0.10.0/25")
|
||||
_, err = db.enableRoutes(&node, "150.0.10.0/25")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
|
||||
|
@ -139,7 +140,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "test_enable_route_node")
|
||||
_, err = db.getNode("test", "test_enable_route_node")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
|
@ -163,16 +164,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
Hostinfo: &hostInfo1,
|
||||
}
|
||||
db.db.Save(&node1)
|
||||
db.DB.Save(&node1)
|
||||
|
||||
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
err = db.enableRoutes(&node1, route.String())
|
||||
_, err = db.enableRoutes(&node1, route.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&node1, route2.String())
|
||||
_, err = db.enableRoutes(&node1, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
|
@ -186,13 +187,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
AuthKeyID: uint(pak.ID),
|
||||
Hostinfo: &hostInfo2,
|
||||
}
|
||||
db.db.Save(&node2)
|
||||
db.DB.Save(&node2)
|
||||
|
||||
sendUpdate, err = db.SaveNodeRoutes(&node2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
err = db.enableRoutes(&node2, route2.String())
|
||||
_, err = db.enableRoutes(&node2, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||
|
@ -219,7 +220,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNode("test", "test_enable_route_node")
|
||||
_, err = db.getNode("test", "test_enable_route_node")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
|
@ -246,22 +247,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
Hostinfo: &hostInfo1,
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&node1)
|
||||
db.DB.Save(&node1)
|
||||
|
||||
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sendUpdate, check.Equals, false)
|
||||
|
||||
err = db.enableRoutes(&node1, prefix.String())
|
||||
_, err = db.enableRoutes(&node1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&node1, prefix2.String())
|
||||
_, err = db.enableRoutes(&node1, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := db.GetNodeRoutes(&node1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.DeleteRoute(uint64(routes[0].ID))
|
||||
// TODO(kradalby): check stateupdate
|
||||
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||
|
@ -269,17 +271,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
}
|
||||
|
||||
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||
|
||||
func TestFailoverRoute(t *testing.T) {
|
||||
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||
|
||||
// TODO(kradalby): Count/verify updates
|
||||
var sink chan types.StateUpdate
|
||||
|
||||
go func() {
|
||||
for range sink {
|
||||
}
|
||||
}()
|
||||
|
||||
machineKeys := []key.MachinePublic{
|
||||
key.NewMachine().Public(),
|
||||
key.NewMachine().Public(),
|
||||
|
@ -291,6 +285,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
name string
|
||||
failingRoute types.Route
|
||||
routes types.Routes
|
||||
isConnected map[key.MachinePublic]bool
|
||||
want []key.MachinePublic
|
||||
wantErr bool
|
||||
}{
|
||||
|
@ -371,6 +366,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
|
@ -382,6 +378,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
|
@ -392,8 +389,13 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
isConnected: map[key.MachinePublic]bool{
|
||||
machineKeys[0]: false,
|
||||
machineKeys[1]: true,
|
||||
},
|
||||
want: []key.MachinePublic{
|
||||
machineKeys[0],
|
||||
machineKeys[1],
|
||||
|
@ -411,6 +413,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
|
@ -422,6 +425,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
|
@ -432,6 +436,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
|
@ -448,6 +453,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
|
@ -459,6 +465,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
|
@ -469,6 +476,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
|
@ -479,8 +487,14 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[2],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
isConnected: map[key.MachinePublic]bool{
|
||||
machineKeys[0]: true,
|
||||
machineKeys[1]: true,
|
||||
machineKeys[2]: true,
|
||||
},
|
||||
want: []key.MachinePublic{
|
||||
machineKeys[1],
|
||||
machineKeys[0],
|
||||
|
@ -498,6 +512,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
|
@ -509,6 +524,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
// Offline
|
||||
types.Route{
|
||||
|
@ -520,8 +536,13 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[3],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
isConnected: map[key.MachinePublic]bool{
|
||||
machineKeys[0]: true,
|
||||
machineKeys[3]: false,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
|
@ -536,6 +557,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
|
@ -547,6 +569,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
// Offline
|
||||
types.Route{
|
||||
|
@ -558,6 +581,7 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[3],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: true,
|
||||
},
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
|
@ -568,14 +592,61 @@ func TestFailoverRoute(t *testing.T) {
|
|||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
isConnected: map[key.MachinePublic]bool{
|
||||
machineKeys[0]: false,
|
||||
machineKeys[1]: true,
|
||||
machineKeys[3]: false,
|
||||
},
|
||||
want: []key.MachinePublic{
|
||||
machineKeys[0],
|
||||
machineKeys[1],
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "failover-primary-none-enabled",
|
||||
failingRoute: types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: 1,
|
||||
},
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Node: types.Node{
|
||||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
routes: types.Routes{
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: 1,
|
||||
},
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Node: types.Node{
|
||||
MachineKey: machineKeys[0],
|
||||
},
|
||||
IsPrimary: true,
|
||||
Enabled: true,
|
||||
},
|
||||
// not enabled
|
||||
types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: 2,
|
||||
},
|
||||
Prefix: ipp("10.0.0.0/24"),
|
||||
Node: types.Node{
|
||||
MachineKey: machineKeys[1],
|
||||
},
|
||||
IsPrimary: false,
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -583,13 +654,14 @@ func TestFailoverRoute(t *testing.T) {
|
|||
tmpDir, err := os.MkdirTemp("", "failover-db-test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
notif := notifier.NewNotifier()
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
notif,
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
|
@ -597,23 +669,15 @@ func TestFailoverRoute(t *testing.T) {
|
|||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Pretend that all the nodes are connected to control
|
||||
for idx, key := range machineKeys {
|
||||
// Pretend one node is offline
|
||||
if idx == 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
notif.AddNode(key, sink)
|
||||
}
|
||||
|
||||
for _, route := range tt.routes {
|
||||
if err := db.db.Save(&route).Error; err != nil {
|
||||
if err := db.DB.Save(&route).Error; err != nil {
|
||||
t.Fatalf("failed to create route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
got, err := db.failoverRoute(&tt.failingRoute)
|
||||
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
|
||||
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -627,3 +691,231 @@ func TestFailoverRoute(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// func TestDisableRouteFailover(t *testing.T) {
|
||||
// machineKeys := []key.MachinePublic{
|
||||
// key.NewMachine().Public(),
|
||||
// key.NewMachine().Public(),
|
||||
// key.NewMachine().Public(),
|
||||
// key.NewMachine().Public(),
|
||||
// }
|
||||
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// nodes types.Nodes
|
||||
|
||||
// routeID uint64
|
||||
// isConnected map[key.MachinePublic]bool
|
||||
|
||||
// wantMachineKey key.MachinePublic
|
||||
// wantErr string
|
||||
// }{
|
||||
// {
|
||||
// name: "single-route",
|
||||
// nodes: types.Nodes{
|
||||
// &types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKeys[0],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 1,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// Node: types.Node{
|
||||
// MachineKey: machineKeys[0],
|
||||
// },
|
||||
// IsPrimary: true,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// routeID: 1,
|
||||
// wantMachineKey: machineKeys[0],
|
||||
// },
|
||||
// {
|
||||
// name: "failover-simple",
|
||||
// nodes: types.Nodes{
|
||||
// &types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKeys[0],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 1,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: true,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// &types.Node{
|
||||
// ID: 1,
|
||||
// MachineKey: machineKeys[1],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 2,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: false,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// routeID: 1,
|
||||
// wantMachineKey: machineKeys[1],
|
||||
// },
|
||||
// {
|
||||
// name: "no-failover-offline",
|
||||
// nodes: types.Nodes{
|
||||
// &types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKeys[0],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 1,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: true,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// &types.Node{
|
||||
// ID: 1,
|
||||
// MachineKey: machineKeys[1],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 2,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: false,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// isConnected: map[key.MachinePublic]bool{
|
||||
// machineKeys[0]: true,
|
||||
// machineKeys[1]: false,
|
||||
// },
|
||||
// routeID: 1,
|
||||
// wantMachineKey: machineKeys[1],
|
||||
// },
|
||||
// {
|
||||
// name: "failover-to-online",
|
||||
// nodes: types.Nodes{
|
||||
// &types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKeys[0],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 1,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: true,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// &types.Node{
|
||||
// ID: 1,
|
||||
// MachineKey: machineKeys[1],
|
||||
// Routes: []types.Route{
|
||||
// {
|
||||
// Model: gorm.Model{
|
||||
// ID: 2,
|
||||
// },
|
||||
// Prefix: ipp("10.0.0.0/24"),
|
||||
// IsPrimary: false,
|
||||
// },
|
||||
// },
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RoutableIPs: []netip.Prefix{
|
||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// isConnected: map[key.MachinePublic]bool{
|
||||
// machineKeys[0]: true,
|
||||
// machineKeys[1]: true,
|
||||
// },
|
||||
// routeID: 1,
|
||||
// wantMachineKey: machineKeys[1],
|
||||
// },
|
||||
// }
|
||||
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
|
||||
// assert.NoError(t, err)
|
||||
|
||||
// // bootstrap db
|
||||
// datab.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// for _, node := range tt.nodes {
|
||||
// err := tx.Save(node).Error
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// _, err = SaveNodeRoutes(tx, node)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// })
|
||||
|
||||
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
// return DisableRoute(tx, tt.routeID, tt.isConnected)
|
||||
// })
|
||||
|
||||
// // if (err.Error() != "") != tt.wantErr {
|
||||
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
// // return
|
||||
// // }
|
||||
|
||||
// if len(got.ChangeNodes) != 1 {
|
||||
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
|
||||
// }
|
||||
|
||||
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
|
||||
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
|
@ -45,9 +46,12 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
|
|
|
@ -15,22 +15,25 @@ var (
|
|||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
|
||||
return CreateUser(tx, name)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := types.User{}
|
||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
user.Name = name
|
||||
if err := hsdb.db.Create(&user).Error; err != nil {
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateUser").
|
||||
Err(err).
|
||||
|
@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
|||
return &user, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DestroyUser(tx, name)
|
||||
})
|
||||
}
|
||||
|
||||
// DestroyUser destroys a User. Returns error if the User does
|
||||
// not exist or if there are nodes associated with it.
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
user, err := hsdb.getUser(name)
|
||||
func DestroyUser(tx *gorm.DB, name string) error {
|
||||
user, err := GetUser(tx, name)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
nodes, err := hsdb.listNodesByUser(name)
|
||||
nodes, err := ListNodesByUser(tx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
|
|||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := hsdb.listPreAuthKeys(name)
|
||||
keys, err := ListPreAuthKeys(tx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = hsdb.destroyPreAuthKey(key)
|
||||
err = DestroyPreAuthKey(tx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
||||
if result := tx.Unscoped().Delete(&user); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return RenameUser(tx, oldName, newName)
|
||||
})
|
||||
}
|
||||
|
||||
// RenameUser renames a User. Returns error if the User does
|
||||
// not exist or if another User exists with the new name.
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||
var err error
|
||||
oldUser, err := hsdb.getUser(oldName)
|
||||
oldUser, err := GetUser(tx, oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = hsdb.getUser(newName)
|
||||
_, err = GetUser(tx, newName)
|
||||
if err == nil {
|
||||
return ErrUserExists
|
||||
}
|
||||
|
@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
|||
|
||||
oldUser.Name = newName
|
||||
|
||||
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
||||
if result := tx.Save(&oldUser); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUser fetches a user by name.
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getUser(name)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUser(rx, name)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
||||
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
|||
return &user, nil
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listUsers()
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
||||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
||||
users := []types.User{}
|
||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
|||
}
|
||||
|
||||
// ListNodesByUser gets all the nodes in a given user.
|
||||
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listNodesByUser(name)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
|
||||
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := hsdb.getUser(name)
|
||||
user, err := GetUser(tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// AssignNodeToUser assigns a Node to a user.
|
||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return AssignNodeToUser(tx, node, username)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser assigns a Node to a user.
|
||||
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
||||
err := util.CheckForFQDNRules(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := hsdb.getUser(username)
|
||||
user, err := GetUser(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.User = *user
|
||||
if result := hsdb.db.Save(&node); result.Error != nil {
|
||||
if result := tx.Save(&node); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
|||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||
// destroying a user also deletes all associated preauthkeys
|
||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
||||
|
@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||
|
@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
|||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&node)
|
||||
db.DB.Save(&node)
|
||||
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
||||
|
||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
||||
|
|
|
@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
RegionID: d.cfg.ServerRegionID,
|
||||
HostName: host,
|
||||
DERPPort: port,
|
||||
IPv4: d.cfg.IPv4,
|
||||
IPv6: d.cfg.IPv6,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
localDERPregion.Nodes[0].STUNPort = portSTUN
|
||||
|
||||
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
|
||||
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
|
||||
|
||||
return localDERPregion, nil
|
||||
}
|
||||
|
@ -208,6 +211,7 @@ func DERPProbeHandler(
|
|||
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
|
||||
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
|
||||
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
|
||||
// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL.
|
||||
func DERPBootstrapDNSHandler(
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) func(http.ResponseWriter, *http.Request) {
|
||||
|
|
|
@ -8,11 +8,13 @@ import (
|
|||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
|||
ctx context.Context,
|
||||
request *v1.ExpirePreAuthKeyRequest,
|
||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
|
||||
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
err = api.h.db.ExpirePreAuthKey(preAuthKey)
|
||||
return db.ExpirePreAuthKey(tx, preAuthKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
node, err := api.h.db.RegisterNodeFromAuthCallback(
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return db.RegisterNodeFromAuthCallback(
|
||||
tx,
|
||||
api.h.registrationCache,
|
||||
mkey,
|
||||
request.GetUser(),
|
||||
nil,
|
||||
util.RegisterMethodCLI,
|
||||
api.h.cfg.IPPrefixes,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.RegisterNode",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
|
@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
|
|||
ctx context.Context,
|
||||
request *v1.SetTagsRequest,
|
||||
) (*v1.SetTagsResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
for _, tag := range request.GetTags() {
|
||||
err := validateTag(tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tag := range request.GetTags() {
|
||||
err := validateTag(tag)
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
err = api.h.db.SetTags(node, request.GetTags())
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.Internal, err.Error())
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.SetTags",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
|
@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||
|
||||
err = api.h.db.DeleteNode(
|
||||
node,
|
||||
api.h.nodeNotifier.ConnectedMap(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
||||
}
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
|
|||
ctx context.Context,
|
||||
request *v1.ExpireNodeRequest,
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
now := time.Now()
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
db.NodeSetExpiry(
|
||||
tx,
|
||||
request.GetNodeId(),
|
||||
now,
|
||||
)
|
||||
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
selfUpdate := types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if selfUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByMachineKey(
|
||||
ctx,
|
||||
selfUpdate,
|
||||
node.MachineKey)
|
||||
}
|
||||
|
||||
api.h.db.NodeSetExpiry(
|
||||
node,
|
||||
now,
|
||||
)
|
||||
stateUpdate := types.StateUpdateExpire(node.ID, now)
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
|
@ -306,19 +365,32 @@ func (api headscaleV1APIServer) RenameNode(
|
|||
ctx context.Context,
|
||||
request *v1.RenameNodeRequest,
|
||||
) (*v1.RenameNodeResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = api.h.db.RenameNode(
|
||||
node,
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.RenameNode(
|
||||
tx,
|
||||
request.GetNodeId(),
|
||||
request.GetNewName(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.GetNodeByID(tx, request.GetNodeId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from api.RenameNode",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Str("new_name", request.GetNewName()).
|
||||
|
@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
ctx context.Context,
|
||||
request *v1.ListNodesRequest,
|
||||
) (*v1.ListNodesResponse, error) {
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
nodes, err := api.h.db.ListNodesByUser(request.GetUser())
|
||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return db.ListNodesByUser(rx, request.GetUser())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
|
||||
resp.Online = isConnected[node.MachineKey]
|
||||
|
||||
response[index] = resp
|
||||
}
|
||||
|
@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
|
|||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
|
||||
resp.Online = isConnected[node.MachineKey]
|
||||
|
||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
||||
&node,
|
||||
node,
|
||||
)
|
||||
resp.InvalidTags = invalidTags
|
||||
resp.ValidTags = validTags
|
||||
|
@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
|
|||
ctx context.Context,
|
||||
request *v1.GetRoutesRequest,
|
||||
) (*v1.GetRoutesResponse, error) {
|
||||
routes, err := api.h.db.GetRoutes()
|
||||
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||
return db.GetRoutes(rx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
|
|||
ctx context.Context,
|
||||
request *v1.EnableRouteRequest,
|
||||
) (*v1.EnableRouteResponse, error) {
|
||||
err := api.h.db.EnableRoute(request.GetRouteId())
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.EnableRoute(tx, request.GetRouteId())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(
|
||||
ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.EnableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
|
|||
ctx context.Context,
|
||||
request *v1.DisableRouteRequest,
|
||||
) (*v1.DisableRouteResponse, error) {
|
||||
err := api.h.db.DisableRoute(request.GetRouteId())
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.DisableRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
|
|||
ctx context.Context,
|
||||
request *v1.DeleteRouteRequest,
|
||||
) (*v1.DeleteRouteResponse, error) {
|
||||
err := api.h.db.DeleteRoute(request.GetRouteId())
|
||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update != nil && update.Valid() {
|
||||
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
|
||||
}
|
||||
|
||||
return &v1.DeleteRouteResponse{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -272,13 +272,26 @@ func (m *Mapper) LiteMapResponse(
|
|||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
pol *policy.ACLPolicy,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
|
||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
||||
pol,
|
||||
node,
|
||||
nodeMapToList(m.peers),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
|
||||
resp.SSHPolicy = sshPolicy
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
|
@ -380,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse(
|
|||
}
|
||||
|
||||
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
|
||||
patches := append(patches, p)
|
||||
|
||||
m.patches[uint64(change.NodeID)] = patches
|
||||
m.patches[uint64(change.NodeID)] = append(patches, p)
|
||||
} else {
|
||||
m.patches[uint64(change.NodeID)] = []patch{p}
|
||||
}
|
||||
|
@ -458,6 +469,8 @@ func (m *Mapper) marshalMapResponse(
|
|||
switch {
|
||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case isSelfUpdate(messages...):
|
||||
responseType = "self"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
|
||||
responseType = "lite"
|
||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||
|
@ -656,3 +669,13 @@ func appendPeerChanges(
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSelfUpdate(messages ...string) bool {
|
||||
for _, message := range messages {
|
||||
if strings.Contains(message, types.SelfUpdateIdentifier) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ func tailNode(
|
|||
}
|
||||
|
||||
var derp string
|
||||
if node.Hostinfo.NetInfo != nil {
|
||||
if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
|
||||
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
|
||||
} else {
|
||||
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -14,24 +15,28 @@ import (
|
|||
type Notifier struct {
|
||||
l sync.RWMutex
|
||||
nodes map[string]chan<- types.StateUpdate
|
||||
connected map[key.MachinePublic]bool
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
return &Notifier{
|
||||
nodes: make(map[string]chan<- types.StateUpdate),
|
||||
connected: make(map[key.MachinePublic]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
|
||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
|
||||
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
|
||||
defer log.Trace().
|
||||
Caller().
|
||||
Str("key", machineKey.ShortString()).
|
||||
Msg("releasing lock to add node")
|
||||
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
||||
if n.nodes == nil {
|
||||
n.nodes = make(map[string]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
n.nodes[machineKey.String()] = c
|
||||
n.connected[machineKey] = true
|
||||
|
||||
log.Trace().
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
|
@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd
|
|||
|
||||
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
|
||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
|
||||
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
|
||||
defer log.Trace().
|
||||
Caller().
|
||||
Str("key", machineKey.ShortString()).
|
||||
Msg("releasing lock to remove node")
|
||||
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
||||
if n.nodes == nil {
|
||||
if len(n.nodes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
delete(n.nodes, machineKey.String())
|
||||
n.connected[machineKey] = false
|
||||
|
||||
log.Trace().
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
|
@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
|
|||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
||||
if _, ok := n.nodes[machineKey.String()]; ok {
|
||||
return true
|
||||
return n.connected[machineKey]
|
||||
}
|
||||
|
||||
return false
|
||||
// TODO(kradalby): This returns a pointer and can be dangerous.
|
||||
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
|
||||
return n.connected
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll(update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(update)
|
||||
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(ctx, update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
|
||||
func (n *Notifier) NotifyWithIgnore(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
ignore ...string,
|
||||
) {
|
||||
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
||||
defer log.Trace().
|
||||
Caller().
|
||||
Interface("type", update.Type).
|
||||
Msg("releasing lock, finished notifing")
|
||||
Msg("releasing lock, finished notifying")
|
||||
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
|
|||
continue
|
||||
}
|
||||
|
||||
log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update")
|
||||
c <- update
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Str("mkey", key).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("hostname", ctx.Value("hostname")).
|
||||
Msgf("update not sent, context cancelled")
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
log.Trace().
|
||||
Str("mkey", key).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("hostname", ctx.Value("hostname")).
|
||||
Msgf("update successfully sent on chan")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) {
|
||||
func (n *Notifier) NotifyByMachineKey(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
mKey key.MachinePublic,
|
||||
) {
|
||||
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
||||
defer log.Trace().
|
||||
Caller().
|
||||
Interface("type", update.Type).
|
||||
Msg("releasing lock, finished notifing")
|
||||
Msg("releasing lock, finished notifying")
|
||||
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
||||
if c, ok := n.nodes[mKey.String()]; ok {
|
||||
c <- update
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Str("mkey", mKey.String()).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("hostname", ctx.Value("hostname")).
|
||||
Msgf("update not sent, context cancelled")
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
log.Trace().
|
||||
Str("mkey", mKey.String()).
|
||||
Any("origin", ctx.Value("origin")).
|
||||
Any("hostname", ctx.Value("hostname")).
|
||||
Msgf("update successfully sent on chan")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
|
@ -569,7 +570,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
Str("node", node.Hostname).
|
||||
Msg("node already registered, reauthenticating")
|
||||
|
||||
err := h.db.NodeSetExpiry(node, expiry)
|
||||
err := h.db.NodeSetExpiry(node.ID, expiry)
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to refresh node")
|
||||
http.Error(
|
||||
|
@ -623,6 +624,12 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
stateUpdate := types.StateUpdateExpire(node.ID, expiry)
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
|
@ -712,14 +719,22 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
machineKey *key.MachinePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
if _, err := h.db.RegisterNodeFromAuthCallback(
|
||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if _, err := db.RegisterNodeFromAuthCallback(
|
||||
// TODO(kradalby): find a better way to use the cache across modules
|
||||
tx,
|
||||
h.registrationCache,
|
||||
*machineKey,
|
||||
user.Name,
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
h.cfg.IPPrefixes,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
util.LogErr(err, "could not register node")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
|
|
@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
|||
if node.IPAddresses.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if node.Hostinfo != nil {
|
||||
// TODO(kradalby): Evaluate if we should only keep
|
||||
// the routes if the route is enabled. This will
|
||||
// require database access in this part of the code.
|
||||
if len(node.Hostinfo.RoutableIPs) > 0 {
|
||||
for _, routableIP := range node.Hostinfo.RoutableIPs {
|
||||
if expanded.ContainsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
|
@ -890,8 +905,14 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
validTags := make([]string, 0)
|
||||
invalidTags := make([]string, 0)
|
||||
|
||||
// TODO(kradalby): Why is this sometimes nil? coming from tailNode?
|
||||
if node == nil {
|
||||
return validTags, invalidTags
|
||||
}
|
||||
|
||||
validTagMap := make(map[string]bool)
|
||||
invalidTagMap := make(map[string]bool)
|
||||
if node.Hostinfo != nil {
|
||||
for _, tag := range node.Hostinfo.RequestTags {
|
||||
owners, err := expandOwnersFromTag(pol, tag)
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
|
@ -917,6 +938,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
for tag := range validTagMap {
|
||||
validTags = append(validTags, tag)
|
||||
}
|
||||
}
|
||||
|
||||
return validTags, invalidTags
|
||||
}
|
||||
|
|
|
@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
},
|
||||
want: []tailcfg.FilterRule{},
|
||||
},
|
||||
{
|
||||
name: "1604-subnet-routers-are-preserved",
|
||||
pol: ACLPolicy{
|
||||
Groups: Groups{
|
||||
"group:admins": {"user1"},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:admins"},
|
||||
Destinations: []string{"group:admins:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:admins"},
|
||||
Destinations: []string{"10.33.0.0/16:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
node: &types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.1"),
|
||||
netip.MustParseAddr("fd7a:115c:a1e0::1"),
|
||||
},
|
||||
User: types.User{Name: "user1"},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
},
|
||||
},
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPAddresses: types.NodeAddresses{
|
||||
netip.MustParseAddr("100.64.0.2"),
|
||||
netip.MustParseAddr("fd7a:115c:a1e0::2"),
|
||||
},
|
||||
User: types.User{Name: "user1"},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{
|
||||
"100.64.0.1/32",
|
||||
"100.64.0.2/32",
|
||||
"fd7a:115c:a1e0::1/128",
|
||||
"fd7a:115c:a1e0::2/128",
|
||||
},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "100.64.0.1/32",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
{
|
||||
IP: "fd7a:115c:a1e0::1/128",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{
|
||||
"100.64.0.1/32",
|
||||
"100.64.0.2/32",
|
||||
"fd7a:115c:a1e0::1/128",
|
||||
"fd7a:115c:a1e0::2/128",
|
||||
},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.33.0.0/16",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
@ -4,12 +4,15 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -125,6 +128,18 @@ func (h *Headscale) handlePoll(
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
if h.ACLPolicy != nil {
|
||||
// update routes with peer information
|
||||
update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
|
||||
if err != nil {
|
||||
logErr(err, "Error running auto approved routes")
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
sendUpdate = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
|
@ -138,41 +153,68 @@ func (h *Headscale) handlePoll(
|
|||
}
|
||||
|
||||
if sendUpdate {
|
||||
if err := h.db.NodeSave(node); err != nil {
|
||||
if err := h.db.DB.Save(node).Error; err != nil {
|
||||
logErr(err, "Failed to persist/update node in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Send an update to all peers to propagate the new routes
|
||||
// available.
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
Message: "called from handlePoll -> update -> new hostinfo",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
stateUpdate,
|
||||
node.MachineKey.String())
|
||||
}
|
||||
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
selfUpdate := types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if selfUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname)
|
||||
h.nodeNotifier.NotifyByMachineKey(
|
||||
ctx,
|
||||
selfUpdate,
|
||||
node.MachineKey)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.db.NodeSave(node); err != nil {
|
||||
if err := h.db.DB.Save(node).Error; err != nil {
|
||||
logErr(err, "Failed to persist/update node in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(kradalby): Figure out why patch changes does
|
||||
// not show up in output from `tailscale debug netmap`.
|
||||
// stateUpdate := types.StateUpdate{
|
||||
// Type: types.StatePeerChangedPatch,
|
||||
// ChangePatches: []*tailcfg.PeerChange{&change},
|
||||
// }
|
||||
stateUpdate := types.StateUpdate{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{&change},
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: types.Nodes{node},
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
stateUpdate,
|
||||
node.MachineKey.String())
|
||||
}
|
||||
|
@ -228,7 +270,7 @@ func (h *Headscale) handlePoll(
|
|||
}
|
||||
}
|
||||
|
||||
if err := h.db.NodeSave(node); err != nil {
|
||||
if err := h.db.DB.Save(node).Error; err != nil {
|
||||
logErr(err, "Failed to persist/update node in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
|
@ -265,7 +307,10 @@ func (h *Headscale) handlePoll(
|
|||
// update ACLRules with peer informations (to update server tags if necessary)
|
||||
if h.ACLPolicy != nil {
|
||||
// update routes with peer information
|
||||
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
|
||||
// This state update is ignored as it will be sent
|
||||
// as part of the whole node
|
||||
// TODO(kradalby): figure out if that is actually correct
|
||||
_, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
|
||||
if err != nil {
|
||||
logErr(err, "Error running auto approved routes")
|
||||
}
|
||||
|
@ -301,11 +346,17 @@ func (h *Headscale) handlePoll(
|
|||
Message: "called from handlePoll -> new node added",
|
||||
}
|
||||
if stateUpdate.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
stateUpdate,
|
||||
node.MachineKey.String())
|
||||
}
|
||||
|
||||
if len(node.Routes) > 0 {
|
||||
go h.pollFailoverRoutes(logErr, "new node", node)
|
||||
}
|
||||
|
||||
// Set up the client stream
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
@ -323,15 +374,9 @@ func (h *Headscale) handlePoll(
|
|||
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
|
||||
ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))
|
||||
defer cancel()
|
||||
|
||||
if len(node.Routes) > 0 {
|
||||
go h.db.EnsureFailoverRouteIsAvailable(node)
|
||||
}
|
||||
|
||||
for {
|
||||
logInfo("Waiting for update on stream channel")
|
||||
select {
|
||||
|
@ -370,6 +415,17 @@ func (h *Headscale) handlePoll(
|
|||
var data []byte
|
||||
var err error
|
||||
|
||||
// Ensure the node object is updated, for example, there
|
||||
// might have been a hostinfo update in a sidechannel
|
||||
// which contains data needed to generate a map response.
|
||||
node, err = h.db.GetNodeByMachineKey(node.MachineKey)
|
||||
if err != nil {
|
||||
logErr(err, "Could not get machine from db")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
startMapResp := time.Now()
|
||||
switch update.Type {
|
||||
case types.StateFullUpdate:
|
||||
logInfo("Sending Full MapResponse")
|
||||
|
@ -378,6 +434,7 @@ func (h *Headscale) handlePoll(
|
|||
case types.StatePeerChanged:
|
||||
logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
|
||||
|
||||
isConnectedMap := h.nodeNotifier.ConnectedMap()
|
||||
for _, node := range update.ChangeNodes {
|
||||
// If a node is not reported to be online, it might be
|
||||
// because the value is outdated, check with the notifier.
|
||||
|
@ -385,7 +442,7 @@ func (h *Headscale) handlePoll(
|
|||
// this might be because it has announced itself, but not
|
||||
// reached the stage to actually create the notifier channel.
|
||||
if node.IsOnline != nil && !*node.IsOnline {
|
||||
isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
|
||||
isOnline := isConnectedMap[node.MachineKey]
|
||||
node.IsOnline = &isOnline
|
||||
}
|
||||
}
|
||||
|
@ -401,7 +458,7 @@ func (h *Headscale) handlePoll(
|
|||
if len(update.ChangeNodes) == 1 {
|
||||
logInfo("Sending SelfUpdate MapResponse")
|
||||
node = update.ChangeNodes[0]
|
||||
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy)
|
||||
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier)
|
||||
} else {
|
||||
logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.")
|
||||
}
|
||||
|
@ -416,8 +473,11 @@ func (h *Headscale) handlePoll(
|
|||
return
|
||||
}
|
||||
|
||||
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response")
|
||||
|
||||
// Only send update if there is change
|
||||
if data != nil {
|
||||
startWrite := time.Now()
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
logErr(err, "Could not write the map response")
|
||||
|
@ -435,6 +495,7 @@ func (h *Headscale) handlePoll(
|
|||
|
||||
return
|
||||
}
|
||||
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node")
|
||||
|
||||
log.Info().
|
||||
Caller().
|
||||
|
@ -454,7 +515,7 @@ func (h *Headscale) handlePoll(
|
|||
go h.updateNodeOnlineStatus(false, node)
|
||||
|
||||
// Failover the node's routes if any.
|
||||
go h.db.FailoverNodeRoutesWithNotify(node)
|
||||
go h.pollFailoverRoutes(logErr, "node closing connection", node)
|
||||
|
||||
// The connection has been closed, so we can stop polling.
|
||||
return
|
||||
|
@ -467,6 +528,22 @@ func (h *Headscale) handlePoll(
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) {
|
||||
update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node)
|
||||
})
|
||||
if err != nil {
|
||||
logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if update != nil && !update.Empty() && update.Valid() {
|
||||
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String())
|
||||
}
|
||||
}
|
||||
|
||||
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
|
||||
// about change in their online/offline status.
|
||||
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
|
||||
|
@ -486,10 +563,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
|
|||
},
|
||||
}
|
||||
if statusUpdate.Valid() {
|
||||
h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
|
||||
}
|
||||
|
||||
err := h.db.UpdateLastSeen(node)
|
||||
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Cannot update node LastSeen")
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
MinimumCapVersion tailcfg.CapabilityVersion = 56
|
||||
MinimumCapVersion tailcfg.CapabilityVersion = 58
|
||||
)
|
||||
|
||||
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
|
||||
|
|
|
@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
}
|
||||
cfg := types.Config{
|
||||
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
|
||||
DBtype: "sqlite3",
|
||||
DBpath: tmpDir + "/headscale_test.db",
|
||||
Database: types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
IPPrefixes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
SelfUpdateIdentifier = "self-update"
|
||||
DatabasePostgres = "postgres"
|
||||
DatabaseSqlite = "sqlite3"
|
||||
)
|
||||
|
||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
|
||||
type IPPrefix netip.Prefix
|
||||
|
@ -150,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
|
|||
}
|
||||
case StateSelfUpdate:
|
||||
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
|
||||
panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
|
||||
panic(
|
||||
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
|
||||
)
|
||||
}
|
||||
case StateDERPUpdated:
|
||||
if su.DERPMap == nil {
|
||||
|
@ -160,3 +170,37 @@ func (su *StateUpdate) Valid() bool {
|
|||
|
||||
return true
|
||||
}
|
||||
|
||||
// Empty reports if there are any updates in the StateUpdate.
|
||||
func (su *StateUpdate) Empty() bool {
|
||||
switch su.Type {
|
||||
case StatePeerChanged:
|
||||
return len(su.ChangeNodes) == 0
|
||||
case StatePeerChangedPatch:
|
||||
return len(su.ChangePatches) == 0
|
||||
case StatePeerRemoved:
|
||||
return len(su.Removed) == 0
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
|
||||
return StateUpdate{
|
||||
Type: StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: tailcfg.NodeID(nodeID),
|
||||
KeyExpiry: &expiry,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
|
||||
ctx2, _ := context.WithTimeout(
|
||||
context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
|
||||
3*time.Second,
|
||||
)
|
||||
return ctx2
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/prometheus/common/model"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -20,6 +19,8 @@ import (
|
|||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -46,16 +47,9 @@ type Config struct {
|
|||
Log LogConfig
|
||||
DisableUpdateCheck bool
|
||||
|
||||
DERP DERPConfig
|
||||
Database DatabaseConfig
|
||||
|
||||
DBtype string
|
||||
DBpath string
|
||||
DBhost string
|
||||
DBport int
|
||||
DBname string
|
||||
DBuser string
|
||||
DBpass string
|
||||
DBssl string
|
||||
DERP DERPConfig
|
||||
|
||||
TLS TLSConfig
|
||||
|
||||
|
@ -77,6 +71,31 @@ type Config struct {
|
|||
ACL ACLConfig
|
||||
}
|
||||
|
||||
type SqliteConfig struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
type PostgresConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Name string
|
||||
User string
|
||||
Pass string
|
||||
Ssl string
|
||||
MaxOpenConnections int
|
||||
MaxIdleConnections int
|
||||
ConnMaxIdleTimeSecs int
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
// Type sets the database type, either "sqlite3" or "postgres"
|
||||
Type string
|
||||
Debug bool
|
||||
|
||||
Sqlite SqliteConfig
|
||||
Postgres PostgresConfig
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
CertPath string
|
||||
KeyPath string
|
||||
|
@ -112,6 +131,7 @@ type OIDCConfig struct {
|
|||
|
||||
type DERPConfig struct {
|
||||
ServerEnabled bool
|
||||
AutomaticallyAddEmbeddedDerpRegion bool
|
||||
ServerRegionID int
|
||||
ServerRegionCode string
|
||||
ServerRegionName string
|
||||
|
@ -121,6 +141,8 @@ type DERPConfig struct {
|
|||
Paths []string
|
||||
AutoUpdate bool
|
||||
UpdateFrequency time.Duration
|
||||
IPv4 string
|
||||
IPv6 string
|
||||
}
|
||||
|
||||
type LogTailConfig struct {
|
||||
|
@ -162,6 +184,19 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
viper.AutomaticEnv()
|
||||
|
||||
viper.RegisterAlias("db_type", "database.type")
|
||||
|
||||
// SQLite aliases
|
||||
viper.RegisterAlias("db_path", "database.sqlite.path")
|
||||
|
||||
// Postgres aliases
|
||||
viper.RegisterAlias("db_host", "database.postgres.host")
|
||||
viper.RegisterAlias("db_port", "database.postgres.port")
|
||||
viper.RegisterAlias("db_name", "database.postgres.name")
|
||||
viper.RegisterAlias("db_user", "database.postgres.user")
|
||||
viper.RegisterAlias("db_pass", "database.postgres.pass")
|
||||
viper.RegisterAlias("db_ssl", "database.postgres.ssl")
|
||||
|
||||
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
|
||||
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
|
||||
|
||||
|
@ -173,6 +208,7 @@ func LoadConfig(path string, isFile bool) error {
|
|||
|
||||
viper.SetDefault("derp.server.enabled", false)
|
||||
viper.SetDefault("derp.server.stun.enabled", true)
|
||||
viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true)
|
||||
|
||||
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
|
||||
viper.SetDefault("unix_socket_permission", "0o770")
|
||||
|
@ -184,6 +220,10 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("cli.insecure", false)
|
||||
|
||||
viper.SetDefault("db_ssl", false)
|
||||
viper.SetDefault("database.postgres.ssl", false)
|
||||
viper.SetDefault("database.postgres.max_open_conns", 10)
|
||||
viper.SetDefault("database.postgres.max_idle_conns", 10)
|
||||
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
||||
|
||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||
viper.SetDefault("oidc.strip_email_domain", true)
|
||||
|
@ -294,8 +334,14 @@ func GetDERPConfig() DERPConfig {
|
|||
serverRegionCode := viper.GetString("derp.server.region_code")
|
||||
serverRegionName := viper.GetString("derp.server.region_name")
|
||||
stunAddr := viper.GetString("derp.server.stun_listen_addr")
|
||||
privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path"))
|
||||
|
||||
privateKeyPath := util.AbsolutePathFromConfigPath(
|
||||
viper.GetString("derp.server.private_key_path"),
|
||||
)
|
||||
ipv4 := viper.GetString("derp.server.ipv4")
|
||||
ipv6 := viper.GetString("derp.server.ipv6")
|
||||
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
|
||||
"derp.server.automatically_add_embedded_derp_region",
|
||||
)
|
||||
if serverEnabled && stunAddr == "" {
|
||||
log.Fatal().
|
||||
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
|
||||
|
@ -318,6 +364,11 @@ func GetDERPConfig() DERPConfig {
|
|||
|
||||
paths := viper.GetStringSlice("derp.paths")
|
||||
|
||||
if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 {
|
||||
log.Fatal().
|
||||
Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths")
|
||||
}
|
||||
|
||||
autoUpdate := viper.GetBool("derp.auto_update_enabled")
|
||||
updateFrequency := viper.GetDuration("derp.update_frequency")
|
||||
|
||||
|
@ -332,6 +383,9 @@ func GetDERPConfig() DERPConfig {
|
|||
Paths: paths,
|
||||
AutoUpdate: autoUpdate,
|
||||
UpdateFrequency: updateFrequency,
|
||||
IPv4: ipv4,
|
||||
IPv6: ipv6,
|
||||
AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -379,6 +433,45 @@ func GetLogConfig() LogConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetDatabaseConfig() DatabaseConfig {
|
||||
debug := viper.GetBool("database.debug")
|
||||
|
||||
type_ := viper.GetString("database.type")
|
||||
|
||||
switch type_ {
|
||||
case DatabaseSqlite, DatabasePostgres:
|
||||
break
|
||||
case "sqlite":
|
||||
type_ = "sqlite3"
|
||||
default:
|
||||
log.Fatal().
|
||||
Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
|
||||
}
|
||||
|
||||
return DatabaseConfig{
|
||||
Type: type_,
|
||||
Debug: debug,
|
||||
Sqlite: SqliteConfig{
|
||||
Path: util.AbsolutePathFromConfigPath(
|
||||
viper.GetString("database.sqlite.path"),
|
||||
),
|
||||
},
|
||||
Postgres: PostgresConfig{
|
||||
Host: viper.GetString("database.postgres.host"),
|
||||
Port: viper.GetInt("database.postgres.port"),
|
||||
Name: viper.GetString("database.postgres.name"),
|
||||
User: viper.GetString("database.postgres.user"),
|
||||
Pass: viper.GetString("database.postgres.pass"),
|
||||
Ssl: viper.GetString("database.postgres.ssl"),
|
||||
MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"),
|
||||
MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"),
|
||||
ConnMaxIdleTimeSecs: viper.GetInt(
|
||||
"database.postgres.conn_max_idle_time_secs",
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
||||
if viper.IsSet("dns_config") {
|
||||
dnsConfig := &tailcfg.DNSConfig{}
|
||||
|
@ -580,7 +673,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oidcClientSecret = string(secretBytes)
|
||||
oidcClientSecret = strings.TrimSpace(string(secretBytes))
|
||||
}
|
||||
|
||||
return &Config{
|
||||
|
@ -607,14 +700,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
"node_update_check_interval",
|
||||
),
|
||||
|
||||
DBtype: viper.GetString("db_type"),
|
||||
DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
|
||||
DBhost: viper.GetString("db_host"),
|
||||
DBport: viper.GetInt("db_port"),
|
||||
DBname: viper.GetString("db_name"),
|
||||
DBuser: viper.GetString("db_user"),
|
||||
DBpass: viper.GetString("db_pass"),
|
||||
DBssl: viper.GetString("db_ssl"),
|
||||
Database: GetDatabaseConfig(),
|
||||
|
||||
TLS: GetTLSConfig(),
|
||||
|
||||
|
|
|
@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
|
|||
// inform peers about smaller changes to the node.
|
||||
// When a field is added to this function, remember to also add it to:
|
||||
// - node.ApplyPeerChange
|
||||
// - logTracePeerChange in poll.go
|
||||
// - logTracePeerChange in poll.go.
|
||||
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
|
||||
ret := tailcfg.PeerChange{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPeerChange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nodeBefore Node
|
||||
change *tailcfg.PeerChange
|
||||
want Node
|
||||
}{
|
||||
{
|
||||
name: "hostinfo-and-netinfo-not-exists",
|
||||
nodeBefore: Node{},
|
||||
change: &tailcfg.PeerChange{
|
||||
DERPRegion: 1,
|
||||
},
|
||||
want: Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
NetInfo: &tailcfg.NetInfo{
|
||||
PreferredDERP: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hostinfo-netinfo-not-exists",
|
||||
nodeBefore: Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test",
|
||||
},
|
||||
},
|
||||
change: &tailcfg.PeerChange{
|
||||
DERPRegion: 3,
|
||||
},
|
||||
want: Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test",
|
||||
NetInfo: &tailcfg.NetInfo{
|
||||
PreferredDERP: 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hostinfo-netinfo-exists-derp-set",
|
||||
nodeBefore: Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test",
|
||||
NetInfo: &tailcfg.NetInfo{
|
||||
PreferredDERP: 999,
|
||||
},
|
||||
},
|
||||
},
|
||||
change: &tailcfg.PeerChange{
|
||||
DERPRegion: 2,
|
||||
},
|
||||
want: Node{
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test",
|
||||
NetInfo: &tailcfg.NetInfo{
|
||||
PreferredDERP: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "endpoints-not-set",
|
||||
nodeBefore: Node{},
|
||||
change: &tailcfg.PeerChange{
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||
},
|
||||
},
|
||||
want: Node{
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "endpoints-set",
|
||||
nodeBefore: Node{
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("6.6.6.6:66"),
|
||||
},
|
||||
},
|
||||
change: &tailcfg.PeerChange{
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||
},
|
||||
},
|
||||
want: Node{
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.nodeBefore.ApplyPeerChange(tt.change)
|
||||
|
||||
if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" {
|
||||
t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package types
|
|||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
|
@ -25,9 +24,10 @@ func (n *User) TailscaleUser() *tailcfg.User {
|
|||
ID: tailcfg.UserID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
// TODO(kradalby): See if we can fill in Gravatar here
|
||||
ProfilePicURL: "",
|
||||
Logins: []tailcfg.LoginID{},
|
||||
Created: time.Time{},
|
||||
Created: n.CreatedAt,
|
||||
}
|
||||
|
||||
return &user
|
||||
|
@ -38,6 +38,7 @@ func (n *User) TailscaleLogin() *tailcfg.Login {
|
|||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
// TODO(kradalby): See if we can fill in Gravatar here
|
||||
ProfilePicURL: "",
|
||||
}
|
||||
|
||||
|
|
|
@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool {
|
|||
return x.Compare(y) == 0
|
||||
})
|
||||
|
||||
var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool {
|
||||
return x == y
|
||||
})
|
||||
|
||||
var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
|
||||
return x.String() == y.String()
|
||||
})
|
||||
|
@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
|
|||
})
|
||||
|
||||
var Comparers []cmp.Option = []cmp.Option{
|
||||
IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer,
|
||||
IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer,
|
||||
}
|
||||
|
|
|
@ -265,6 +265,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
@ -325,6 +327,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
|
|
@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
|
|
@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
|||
assert.Contains(t, listAll[4].GetGivenName(), "node-5")
|
||||
|
||||
for idx := 0; idx < 3; idx++ {
|
||||
_, err := headscale.Execute(
|
||||
res, err := headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
|
@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) {
|
|||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Contains(t, res, "Node renamed")
|
||||
}
|
||||
|
||||
var listAllAfterRename []v1.Node
|
||||
|
|
|
@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) {
|
|||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
"user1": 10,
|
||||
// "user1": len(MustTestVersions),
|
||||
}
|
||||
|
||||
headscaleConfig := map[string]string{}
|
||||
headscaleConfig["HEADSCALE_DERP_URLS"] = ""
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true"
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999"
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale"
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP"
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478"
|
||||
headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key"
|
||||
headscaleConfig := map[string]string{
|
||||
"HEADSCALE_DERP_URLS": "",
|
||||
"HEADSCALE_DERP_SERVER_ENABLED": "true",
|
||||
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
|
||||
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
|
||||
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
|
||||
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
|
||||
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
|
||||
|
||||
// Envknob for enabling DERP debug logs
|
||||
headscaleConfig["DERP_DEBUG_LOGS"] = "true"
|
||||
headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true"
|
||||
"DERP_DEBUG_LOGS": "true",
|
||||
"DERP_PROBER_DEBUG_LOGS": "true",
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
spec,
|
||||
|
@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
|
||||
|
|
|
@ -26,12 +26,34 @@ func TestPingAllByIP(t *testing.T) {
|
|||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
|
||||
// TODO(kradalby): it does not look like the user thing works, only second
|
||||
// get created? maybe only when many?
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
"user2": len(MustTestVersions),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip"))
|
||||
headscaleConfig := map[string]string{
|
||||
"HEADSCALE_DERP_URLS": "",
|
||||
"HEADSCALE_DERP_SERVER_ENABLED": "true",
|
||||
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
|
||||
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
|
||||
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
|
||||
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
|
||||
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
|
||||
|
||||
// Envknob for enabling DERP debug logs
|
||||
"DERP_DEBUG_LOGS": "true",
|
||||
"DERP_PROBER_DEBUG_LOGS": "true",
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{},
|
||||
hsic.WithTestName("pingallbyip"),
|
||||
hsic.WithConfigEnv(headscaleConfig),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithHostnameAsServerURL(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
|
@ -43,6 +65,46 @@ func TestPingAllByIP(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
}
|
||||
|
||||
func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
scenario, err := NewScenario()
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
"user2": len(MustTestVersions),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{},
|
||||
hsic.WithTestName("pingallbyippubderp"),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
@ -73,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||
for _, client := range allClients {
|
||||
ips, err := client.IPs()
|
||||
|
@ -112,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allClients, err = scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
|
@ -263,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
|
||||
|
@ -320,9 +388,13 @@ func TestTaildrop(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
}
|
||||
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
|
||||
curlCommand := []string{
|
||||
"curl",
|
||||
"--unix-socket",
|
||||
"/var/run/tailscale/tailscaled.sock",
|
||||
"http://local-tailscaled.sock/localapi/v0/file-targets",
|
||||
}
|
||||
err = retry(10, 1*time.Second, func() error {
|
||||
result, _, err := client.Execute(curlCommand)
|
||||
if err != nil {
|
||||
|
@ -339,13 +411,23 @@ func TestTaildrop(t *testing.T) {
|
|||
for _, ft := range fts {
|
||||
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
|
||||
}
|
||||
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
|
||||
return fmt.Errorf(
|
||||
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
|
||||
client.Hostname(),
|
||||
len(fts),
|
||||
len(allClients)-1,
|
||||
ftStr,
|
||||
)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
|
||||
t.Errorf(
|
||||
"failed to query localapi for filetarget on %s, err: %s",
|
||||
client.Hostname(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -457,6 +539,8 @@ func TestResolveMagicDNS(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
// Poor mans cache
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
|
@ -525,6 +609,8 @@ func TestExpireNode(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
@ -546,7 +632,7 @@ func TestExpireNode(t *testing.T) {
|
|||
// TODO(kradalby): This is Headscale specific and would not play nicely
|
||||
// with other implementations of the ControlServer interface
|
||||
result, err := headscale.Execute([]string{
|
||||
"headscale", "nodes", "expire", "--identifier", "0", "--output", "json",
|
||||
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
|
||||
|
@ -577,16 +663,38 @@ func TestExpireNode(t *testing.T) {
|
|||
assertNotNil(t, peerStatus.Expired)
|
||||
assert.NotNil(t, peerStatus.KeyExpiry)
|
||||
|
||||
t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
|
||||
t.Logf(
|
||||
"node %q should have a key expire before %s, was %s",
|
||||
peerStatus.HostName,
|
||||
now.String(),
|
||||
peerStatus.KeyExpiry,
|
||||
)
|
||||
if peerStatus.KeyExpiry != nil {
|
||||
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
|
||||
assert.Truef(
|
||||
t,
|
||||
peerStatus.KeyExpiry.Before(now),
|
||||
"node %q should have a key expire before %s, was %s",
|
||||
peerStatus.HostName,
|
||||
now.String(),
|
||||
peerStatus.KeyExpiry,
|
||||
)
|
||||
}
|
||||
|
||||
assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
|
||||
assert.Truef(
|
||||
t,
|
||||
peerStatus.Expired,
|
||||
"node %q should be expired, expired is %v",
|
||||
peerStatus.HostName,
|
||||
peerStatus.Expired,
|
||||
)
|
||||
|
||||
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
|
||||
if !strings.Contains(stderr, "node key has expired") {
|
||||
t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname())
|
||||
t.Errorf(
|
||||
"expected to be unable to ping expired host %q from %q",
|
||||
node.GetName(),
|
||||
client.Hostname(),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
|
||||
|
@ -598,7 +706,7 @@ func TestExpireNode(t *testing.T) {
|
|||
|
||||
// NeedsLogin means that the node has understood that it is no longer
|
||||
// valid.
|
||||
assert.Equal(t, "NeedsLogin", status.BackendState)
|
||||
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -627,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
|||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
assertClientsState(t, allClients)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
@ -691,7 +801,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
|||
assert.Truef(
|
||||
t,
|
||||
lastSeen.After(lastSeenThreshold),
|
||||
"lastSeen (%v) was not %s after the threshold (%v)",
|
||||
"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
|
||||
node.GetName(),
|
||||
lastSeen,
|
||||
keepAliveInterval,
|
||||
lastSeenThreshold,
|
||||
|
|
|
@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string {
|
|||
return map[string]string{
|
||||
"HEADSCALE_LOG_LEVEL": "trace",
|
||||
"HEADSCALE_ACL_POLICY_PATH": "",
|
||||
"HEADSCALE_DB_TYPE": "sqlite3",
|
||||
"HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3",
|
||||
"HEADSCALE_DATABASE_TYPE": "sqlite",
|
||||
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
|
||||
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
|
||||
"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s",
|
||||
"HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10",
|
||||
|
|
|
@ -9,10 +9,15 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
// This test is both testing the routes command and the propagation of
|
||||
|
@ -83,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
assert.Len(t, routes, 3)
|
||||
|
||||
for _, route := range routes {
|
||||
assert.Equal(t, route.GetAdvertised(), true)
|
||||
assert.Equal(t, route.GetEnabled(), false)
|
||||
assert.Equal(t, route.GetIsPrimary(), false)
|
||||
assert.Equal(t, true, route.GetAdvertised())
|
||||
assert.Equal(t, false, route.GetEnabled())
|
||||
assert.Equal(t, false, route.GetIsPrimary())
|
||||
}
|
||||
|
||||
// Verify that no routes has been sent to the client,
|
||||
|
@ -130,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
assert.Len(t, enablingRoutes, 3)
|
||||
|
||||
for _, route := range enablingRoutes {
|
||||
assert.Equal(t, route.GetAdvertised(), true)
|
||||
assert.Equal(t, route.GetEnabled(), true)
|
||||
assert.Equal(t, route.GetIsPrimary(), true)
|
||||
assert.Equal(t, true, route.GetAdvertised())
|
||||
assert.Equal(t, true, route.GetEnabled())
|
||||
assert.Equal(t, true, route.GetIsPrimary())
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
@ -186,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
})
|
||||
assertNoErr(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
var disablingRoutes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
|
@ -204,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
assert.Equal(t, true, route.GetAdvertised())
|
||||
|
||||
if route.GetId() == routeToBeDisabled.GetId() {
|
||||
assert.Equal(t, route.GetEnabled(), false)
|
||||
assert.Equal(t, route.GetIsPrimary(), false)
|
||||
assert.Equal(t, false, route.GetEnabled())
|
||||
assert.Equal(t, false, route.GetIsPrimary())
|
||||
} else {
|
||||
assert.Equal(t, route.GetEnabled(), true)
|
||||
assert.Equal(t, route.GetIsPrimary(), true)
|
||||
assert.Equal(t, true, route.GetEnabled())
|
||||
assert.Equal(t, true, route.GetIsPrimary())
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
|
@ -289,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
// advertise HA route on node 1 and 2
|
||||
// ID 1 will be primary
|
||||
// ID 2 will be secondary
|
||||
for _, client := range allClients {
|
||||
for _, client := range allClients[:2] {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
|
||||
|
@ -301,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
}
|
||||
_, _, err = client.Execute(command)
|
||||
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||
} else {
|
||||
t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -323,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assertNoErr(t, err)
|
||||
assert.Len(t, routes, 2)
|
||||
|
||||
t.Logf("initial routes %#v", routes)
|
||||
|
||||
for _, route := range routes {
|
||||
assert.Equal(t, true, route.GetAdvertised())
|
||||
assert.Equal(t, false, route.GetEnabled())
|
||||
|
@ -639,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assertNoErr(t, err)
|
||||
assert.Len(t, routesAfterDisabling1, 2)
|
||||
|
||||
t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
|
||||
|
||||
// Node 1 is not primary
|
||||
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
|
||||
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
|
||||
|
@ -778,3 +789,413 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
expectedRoutes := "172.0.0.0/24"
|
||||
|
||||
user := "enable-disable-routing"
|
||||
|
||||
scenario, err := NewScenario()
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
user: 1,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
|
||||
&policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
TagOwners: map[string][]string{
|
||||
"tag:approve": {user},
|
||||
},
|
||||
AutoApprovers: policy.AutoApprovers{
|
||||
Routes: map[string][]string{
|
||||
expectedRoutes: {"tag:approve"},
|
||||
},
|
||||
},
|
||||
},
|
||||
))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
|
||||
subRouter1 := allClients[0]
|
||||
|
||||
// Initially advertise route
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"set",
|
||||
"--advertise-routes=" + expectedRoutes,
|
||||
}
|
||||
_, _, err = subRouter1.Execute(command)
|
||||
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
var routes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&routes,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, routes, 1)
|
||||
|
||||
// All routes should be auto approved and enabled
|
||||
assert.Equal(t, true, routes[0].GetAdvertised())
|
||||
assert.Equal(t, true, routes[0].GetEnabled())
|
||||
assert.Equal(t, true, routes[0].GetIsPrimary())
|
||||
|
||||
// Stop advertising route
|
||||
command = []string{
|
||||
"tailscale",
|
||||
"set",
|
||||
"--advertise-routes=",
|
||||
}
|
||||
_, _, err = subRouter1.Execute(command)
|
||||
assertNoErrf(t, "failed to remove advertised route: %s", err)
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
var notAdvertisedRoutes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
¬AdvertisedRoutes,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, notAdvertisedRoutes, 1)
|
||||
|
||||
// Route is no longer advertised
|
||||
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised())
|
||||
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled())
|
||||
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary())
|
||||
|
||||
// Advertise route again
|
||||
command = []string{
|
||||
"tailscale",
|
||||
"set",
|
||||
"--advertise-routes=" + expectedRoutes,
|
||||
}
|
||||
_, _, err = subRouter1.Execute(command)
|
||||
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
var reAdvertisedRoutes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&reAdvertisedRoutes,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, reAdvertisedRoutes, 1)
|
||||
|
||||
// All routes should be auto approved and enabled
|
||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised())
|
||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled())
|
||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary())
|
||||
}
|
||||
|
||||
// TestSubnetRouteACL verifies that Subnet routes are distributed
|
||||
// as expected when ACLs are activated.
|
||||
// It implements the issue from
|
||||
// https://github.com/juanfont/headscale/issues/1604
|
||||
func TestSubnetRouteACL(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
user := "subnet-route-acl"
|
||||
|
||||
scenario, err := NewScenario()
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
user: 2,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
|
||||
&policy.ACLPolicy{
|
||||
Groups: policy.Groups{
|
||||
"group:admins": {user},
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:admins"},
|
||||
Destinations: []string{"group:admins:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:admins"},
|
||||
Destinations: []string{"10.33.0.0/16:*"},
|
||||
},
|
||||
// {
|
||||
// Action: "accept",
|
||||
// Sources: []string{"group:admins"},
|
||||
// Destinations: []string{"0.0.0.0/0:*"},
|
||||
// },
|
||||
},
|
||||
},
|
||||
))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
|
||||
expectedRoutes := map[string]string{
|
||||
"1": "10.33.0.0/16",
|
||||
}
|
||||
|
||||
// Sort nodes by ID
|
||||
sort.SliceStable(allClients, func(i, j int) bool {
|
||||
statusI, err := allClients[i].Status()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusJ, err := allClients[j].Status()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return statusI.Self.ID < statusJ.Self.ID
|
||||
})
|
||||
|
||||
subRouter1 := allClients[0]
|
||||
|
||||
client := allClients[1]
|
||||
|
||||
// advertise HA route on node 1 and 2
|
||||
// ID 1 will be primary
|
||||
// ID 2 will be secondary
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
|
||||
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"set",
|
||||
"--advertise-routes=" + route,
|
||||
}
|
||||
_, _, err = client.Execute(command)
|
||||
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
var routes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&routes,
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, routes, 1)
|
||||
|
||||
for _, route := range routes {
|
||||
assert.Equal(t, true, route.GetAdvertised())
|
||||
assert.Equal(t, false, route.GetEnabled())
|
||||
assert.Equal(t, false, route.GetIsPrimary())
|
||||
}
|
||||
|
||||
// Verify that no routes has been sent to the client,
|
||||
// they are not yet enabled.
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.Nil(t, peerStatus.PrimaryRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// Enable all routes
|
||||
for _, route := range routes {
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"enable",
|
||||
"--route",
|
||||
strconv.Itoa(int(route.GetId())),
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
var enablingRoutes []*v1.Route
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"routes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&enablingRoutes,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, enablingRoutes, 1)
|
||||
|
||||
// Node 1 has active route
|
||||
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
||||
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
||||
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1, _ := subRouter1.Status()
|
||||
|
||||
clientStatus, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
|
||||
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
|
||||
|
||||
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
|
||||
t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice())
|
||||
assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1)
|
||||
assert.Contains(
|
||||
t,
|
||||
srs1PeerStatus.PrimaryRoutes.AsSlice(),
|
||||
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
|
||||
)
|
||||
|
||||
clientNm, err := client.Netmap()
|
||||
assertNoErr(t, err)
|
||||
|
||||
wantClientFilter := []filter.Match{
|
||||
{
|
||||
IPProto: []ipproto.Proto{
|
||||
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||
},
|
||||
Srcs: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("100.64.0.2/32"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
},
|
||||
Dsts: []filter.NetPortRange{
|
||||
{
|
||||
Net: netip.MustParsePrefix("100.64.0.2/32"),
|
||||
Ports: filter.PortRange{0, 0xffff},
|
||||
},
|
||||
{
|
||||
Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
Ports: filter.PortRange{0, 0xffff},
|
||||
},
|
||||
},
|
||||
Caps: []filter.CapMatch{},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
|
||||
}
|
||||
|
||||
subnetNm, err := subRouter1.Netmap()
|
||||
assertNoErr(t, err)
|
||||
|
||||
wantSubnetFilter := []filter.Match{
|
||||
{
|
||||
IPProto: []ipproto.Proto{
|
||||
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||
},
|
||||
Srcs: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("100.64.0.2/32"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
},
|
||||
Dsts: []filter.NetPortRange{
|
||||
{
|
||||
Net: netip.MustParsePrefix("100.64.0.1/32"),
|
||||
Ports: filter.PortRange{0, 0xffff},
|
||||
},
|
||||
{
|
||||
Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
Ports: filter.PortRange{0, 0xffff},
|
||||
},
|
||||
},
|
||||
Caps: []filter.CapMatch{},
|
||||
},
|
||||
{
|
||||
IPProto: []ipproto.Proto{
|
||||
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||
},
|
||||
Srcs: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("100.64.0.2/32"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
},
|
||||
Dsts: []filter.NetPortRange{
|
||||
{
|
||||
Net: netip.MustParsePrefix("10.33.0.0/16"),
|
||||
Ports: filter.PortRange{0, 0xffff},
|
||||
},
|
||||
},
|
||||
Caps: []filter.CapMatch{},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
|
||||
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,8 +56,8 @@ var (
|
|||
"1.44": true, // CapVer: 63
|
||||
"1.42": true, // CapVer: 61
|
||||
"1.40": true, // CapVer: 61
|
||||
"1.38": true, // CapVer: 58
|
||||
"1.36": true, // Oldest supported version, CapVer: 56
|
||||
"1.38": true, // Oldest supported version, CapVer: 58
|
||||
"1.36": false, // CapVer: 56
|
||||
"1.34": false, // CapVer: 51
|
||||
"1.32": false, // CapVer: 46
|
||||
"1.30": false,
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/netcheck"
|
||||
"tailscale.com/types/netmap"
|
||||
)
|
||||
|
||||
// nolint
|
||||
|
@ -26,6 +28,8 @@ type TailscaleClient interface {
|
|||
IPs() ([]netip.Addr, error)
|
||||
FQDN() (string, error)
|
||||
Status() (*ipnstate.Status, error)
|
||||
Netmap() (*netmap.NetworkMap, error)
|
||||
Netcheck() (*netcheck.Report, error)
|
||||
WaitForNeedsLogin() error
|
||||
WaitForRunning() error
|
||||
WaitForPeers(expected int) error
|
||||
|
|
|
@ -17,6 +17,8 @@ import (
|
|||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/netcheck"
|
||||
"tailscale.com/types/netmap"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -519,6 +521,53 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
|
|||
return &status, err
|
||||
}
|
||||
|
||||
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
|
||||
// Only works with Tailscale 1.56.1 and newer.
|
||||
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"debug",
|
||||
"netmap",
|
||||
}
|
||||
|
||||
result, stderr, err := t.Execute(command)
|
||||
if err != nil {
|
||||
fmt.Printf("stderr: %s\n", stderr)
|
||||
return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
|
||||
}
|
||||
|
||||
var nm netmap.NetworkMap
|
||||
err = json.Unmarshal([]byte(result), &nm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
|
||||
}
|
||||
|
||||
return &nm, err
|
||||
}
|
||||
|
||||
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
|
||||
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"netcheck",
|
||||
"--format=json",
|
||||
}
|
||||
|
||||
result, stderr, err := t.Execute(command)
|
||||
if err != nil {
|
||||
fmt.Printf("stderr: %s\n", stderr)
|
||||
return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err)
|
||||
}
|
||||
|
||||
var nm netcheck.Report
|
||||
err = json.Unmarshal([]byte(result), &nm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
|
||||
}
|
||||
|
||||
return &nm, err
|
||||
}
|
||||
|
||||
// FQDN returns the FQDN as a string of the Tailscale instance.
|
||||
func (t *TailscaleInContainer) FQDN() (string, error) {
|
||||
if t.fqdn != "" {
|
||||
|
@ -623,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
|
|||
len(peers),
|
||||
)
|
||||
} else {
|
||||
// Verify that the peers of a given node is Online
|
||||
// has a hostname and a DERP relay.
|
||||
for _, peerKey := range peers {
|
||||
peer := status.Peer[peerKey]
|
||||
|
||||
if !peer.Online {
|
||||
return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
|
||||
}
|
||||
|
||||
if peer.HostName == "" {
|
||||
return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)
|
||||
}
|
||||
|
||||
if peer.Relay == "" {
|
||||
return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/util/cmpver"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
|
|||
for _, addr := range addrs {
|
||||
err := client.Ping(addr, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
|
||||
t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
|
@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string)
|
|||
return success
|
||||
}
|
||||
|
||||
// assertClientsState validates the status and netmap of a list of
|
||||
// clients for the general case of all to all connectivity.
|
||||
func assertClientsState(t *testing.T, clients []TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
for _, client := range clients {
|
||||
assertValidStatus(t, client)
|
||||
assertValidNetmap(t, client)
|
||||
assertValidNetcheck(t, client)
|
||||
}
|
||||
}
|
||||
|
||||
// assertValidNetmap asserts that the netmap of a client has all
|
||||
// the minimum required fields set to a known working config for
|
||||
// the general case. Fields are checked on self, then all peers.
|
||||
// This test is not suitable for ACL/partial connection tests.
|
||||
// This test can only be run on clients from 1.56.1. It will
|
||||
// automatically pass all clients below that and is safe to call
|
||||
// for all versions.
|
||||
func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
if cmpver.Compare("1.56.1", client.Version()) <= 0 ||
|
||||
!strings.Contains(client.Hostname(), "unstable") ||
|
||||
!strings.Contains(client.Hostname(), "head") {
|
||||
return
|
||||
}
|
||||
|
||||
netmap, err := client.Netmap()
|
||||
if err != nil {
|
||||
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||
|
||||
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
|
||||
|
||||
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
|
||||
|
||||
for _, peer := range netmap.Peers {
|
||||
assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP())
|
||||
|
||||
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||
if hi := peer.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||
|
||||
// Netinfo is not always set
|
||||
assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
|
||||
if ni := hi.NetInfo(); ni.Valid() {
|
||||
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// assertValidStatus asserts that the status of a client has all
|
||||
// the minimum required fields set to a known working config for
|
||||
// the general case. Fields are checked on self, then all peers.
|
||||
// This test is not suitable for ACL/partial connection tests.
|
||||
func assertValidStatus(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
status, err := client.Status()
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
|
||||
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
|
||||
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
|
||||
|
||||
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
|
||||
|
||||
// This seem to not appear until version 1.56
|
||||
if status.Self.AllowedIPs != nil {
|
||||
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
|
||||
|
||||
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
|
||||
|
||||
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
|
||||
|
||||
// This isnt really relevant for Self as it wont be in its own socket/wireguard.
|
||||
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
|
||||
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
|
||||
|
||||
for _, peer := range status.Peer {
|
||||
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
|
||||
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
|
||||
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
|
||||
|
||||
// This seem to not appear until version 1.56
|
||||
if peer.AllowedIPs != nil {
|
||||
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
|
||||
}
|
||||
|
||||
// Addrs does not seem to appear in the status from peers.
|
||||
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
|
||||
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
|
||||
|
||||
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
|
||||
// there might be some interesting stuff to test here in the future.
|
||||
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
report, err := client.Netcheck()
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
|
||||
}
|
||||
|
||||
func isSelfClient(client TailscaleClient, addr string) bool {
|
||||
if addr == client.Hostname() {
|
||||
return true
|
||||
|
@ -152,7 +296,7 @@ func isCI() bool {
|
|||
}
|
||||
|
||||
func dockertestMaxWait() time.Duration {
|
||||
wait := 60 * time.Second //nolint
|
||||
wait := 120 * time.Second //nolint
|
||||
|
||||
if isCI() {
|
||||
wait = 300 * time.Second //nolint
|
||||
|
|
Loading…
Reference in a new issue