Merge upstream/main into fork
Commit cc430af3f9
52 changed files with 3272 additions and 1415 deletions
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 13 changes)
@@ -50,3 +50,16 @@ instead of filing a bug report.
 ## To Reproduce

 <!-- Steps to reproduce the behavior. -->
+
+## Logs and attachments
+
+<!-- Please attach files with:
+
+- Client netmap dump (see below)
+- ACL configuration
+- Headscale configuration
+
+Dump the netmap of tailscale clients:
+`tailscale debug netmap > DESCRIPTIVE_NAME.json`
+
+Please provide information describing the netmap, which client, which headscale version etc.
+-->
.github/workflows/test-integration-v2-TestEnableDisableAutoApprovedRoute.yaml (vendored, new file, 67 changes)
@@ -0,0 +1,67 @@
+# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
+# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
+
+name: Integration Test v2 - TestEnableDisableAutoApprovedRoute
+
+on: [pull_request]
+
+concurrency:
+  group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  TestEnableDisableAutoApprovedRoute:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - uses: DeterminateSystems/nix-installer-action@main
+      - uses: DeterminateSystems/magic-nix-cache-action@main
+      - uses: satackey/action-docker-layer-caching@main
+        continue-on-error: true
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v34
+        with:
+          files: |
+            *.nix
+            go.*
+            **/*.go
+            integration_test/
+            config-example.yaml
+
+      - name: Run TestEnableDisableAutoApprovedRoute
+        uses: Wandalen/wretry.action@master
+        if: steps.changed-files.outputs.any_changed == 'true'
+        with:
+          attempt_limit: 5
+          command: |
+            nix develop --command -- docker run \
+              --tty --rm \
+              --volume ~/.cache/hs-integration-go:/go \
+              --name headscale-test-suite \
+              --volume $PWD:$PWD -w $PWD/integration \
+              --volume /var/run/docker.sock:/var/run/docker.sock \
+              --volume $PWD/control_logs:/tmp/control \
+              golang:1 \
+              go run gotest.tools/gotestsum@latest -- ./... \
+                -failfast \
+                -timeout 120m \
+                -parallel 1 \
+                -run "^TestEnableDisableAutoApprovedRoute$"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: logs
+          path: "control_logs/*.log"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: pprof
+          path: "control_logs/*.pprof.tar"
.github/workflows/test-integration-v2-TestPingAllByIPPublicDERP.yaml (vendored, new file, 67 changes)
@@ -0,0 +1,67 @@
+# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
+# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
+
+name: Integration Test v2 - TestPingAllByIPPublicDERP
+
+on: [pull_request]
+
+concurrency:
+  group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  TestPingAllByIPPublicDERP:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - uses: DeterminateSystems/nix-installer-action@main
+      - uses: DeterminateSystems/magic-nix-cache-action@main
+      - uses: satackey/action-docker-layer-caching@main
+        continue-on-error: true
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v34
+        with:
+          files: |
+            *.nix
+            go.*
+            **/*.go
+            integration_test/
+            config-example.yaml
+
+      - name: Run TestPingAllByIPPublicDERP
+        uses: Wandalen/wretry.action@master
+        if: steps.changed-files.outputs.any_changed == 'true'
+        with:
+          attempt_limit: 5
+          command: |
+            nix develop --command -- docker run \
+              --tty --rm \
+              --volume ~/.cache/hs-integration-go:/go \
+              --name headscale-test-suite \
+              --volume $PWD:$PWD -w $PWD/integration \
+              --volume /var/run/docker.sock:/var/run/docker.sock \
+              --volume $PWD/control_logs:/tmp/control \
+              golang:1 \
+              go run gotest.tools/gotestsum@latest -- ./... \
+                -failfast \
+                -timeout 120m \
+                -parallel 1 \
+                -run "^TestPingAllByIPPublicDERP$"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: logs
+          path: "control_logs/*.log"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: pprof
+          path: "control_logs/*.pprof.tar"
.github/workflows/test-integration-v2-TestSubnetRouteACL.yaml (vendored, new file, 67 changes)
@@ -0,0 +1,67 @@
+# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
+# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
+
+name: Integration Test v2 - TestSubnetRouteACL
+
+on: [pull_request]
+
+concurrency:
+  group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  TestSubnetRouteACL:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - uses: DeterminateSystems/nix-installer-action@main
+      - uses: DeterminateSystems/magic-nix-cache-action@main
+      - uses: satackey/action-docker-layer-caching@main
+        continue-on-error: true
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v34
+        with:
+          files: |
+            *.nix
+            go.*
+            **/*.go
+            integration_test/
+            config-example.yaml
+
+      - name: Run TestSubnetRouteACL
+        uses: Wandalen/wretry.action@master
+        if: steps.changed-files.outputs.any_changed == 'true'
+        with:
+          attempt_limit: 5
+          command: |
+            nix develop --command -- docker run \
+              --tty --rm \
+              --volume ~/.cache/hs-integration-go:/go \
+              --name headscale-test-suite \
+              --volume $PWD:$PWD -w $PWD/integration \
+              --volume /var/run/docker.sock:/var/run/docker.sock \
+              --volume $PWD/control_logs:/tmp/control \
+              golang:1 \
+              go run gotest.tools/gotestsum@latest -- ./... \
+                -failfast \
+                -timeout 120m \
+                -parallel 1 \
+                -run "^TestSubnetRouteACL$"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: logs
+          path: "control_logs/*.log"
+
+      - uses: actions/upload-artifact@v3
+        if: always() && steps.changed-files.outputs.any_changed == 'true'
+        with:
+          name: pprof
+          path: "control_logs/*.pprof.tar"
CHANGELOG.md (28 changes)
@@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460)
 - Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473)
 - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
 - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
-  - The latest supported client is 1.36
+  - The latest supported client is 1.38
 - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
   - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url.
 - Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611)
@@ -34,17 +34,21 @@ after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460)

 ### Changes

-Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
-Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
-Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
-SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
-State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
-Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
-Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
-Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
-Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
-Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
-Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
+- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
+- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
+- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
+- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
+- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
+- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
+- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
+- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
+- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
+- Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
+- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700)
+  - Old structure is now considered deprecated and will be removed in the future.
+  - Adds additional configuration for PostgreSQL for setting max open, idle conection and idle connection lifetime.
+- Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
+- Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)

 ## 0.22.3 (2023-05-12)

@@ -6,25 +6,11 @@ import (
     "github.com/efekarakus/termcolor"
     "github.com/juanfont/headscale/cmd/headscale/cli"
-    "github.com/pkg/profile"
     "github.com/rs/zerolog"
     "github.com/rs/zerolog/log"
 )

 func main() {
-    if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
-        if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
-            err := os.MkdirAll(profilePath, os.ModePerm)
-            if err != nil {
-                log.Fatal().Err(err).Msg("failed to create profiling directory")
-            }
-
-            defer profile.Start(profile.ProfilePath(profilePath)).Stop()
-        } else {
-            defer profile.Start().Stop()
-        }
-    }
-
     var colors bool
     switch l := termcolor.SupportLevel(os.Stderr); l {
     case termcolor.Level16M:
@@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
     c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
     c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
     c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
-    c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
+    c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
     c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
+    c.Assert(viper.GetString("database.type"), check.Equals, "sqlite")
+    c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite")
     c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
     c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
     c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
@@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
     c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
     c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
     c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
-    c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
+    c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
     c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
     c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
     c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
@@ -94,6 +94,16 @@ derp:
     #
     private_key_path: /var/lib/headscale/derp_server_private.key

+    # This flag can be used, so the DERP map entry for the embedded DERP server is not written automatically,
+    # it enables the creation of your very own DERP map entry using a locally available file with the parameter DERP.paths
+    # If you enable the DERP server and set this to false, it is required to add the DERP server to the DERP map using DERP.paths
+    automatically_add_embedded_derp_region: true
+
+    # For better connection stability (especially when using an Exit-Node and DNS is not working),
+    # it is possible to optionall add the public IPv4 and IPv6 address to the Derp-Map using:
+    ipv4: 1.2.3.4
+    ipv6: 2001:db8::1
+
   # List of externally available DERP maps encoded in JSON
   urls:
     - https://controlplane.tailscale.com/derpmap/default
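For context on the hunk above: when `automatically_add_embedded_derp_region` is set to `false`, the embedded DERP server has to be added to the DERP map by hand, as the new comment says. A minimal sketch of what that could look like, assuming the pre-existing `derp.paths` option and a hypothetical local map file (the path and values are placeholders, not taken from this commit):

```yaml
# config.yaml (sketch)
derp:
  server:
    enabled: true
    # headscale will no longer inject the embedded region into the DERP map,
    # so it must appear in one of the files listed under derp.paths.
    automatically_add_embedded_derp_region: false
  paths:
    - /etc/headscale/derp.yaml
```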
@@ -128,24 +138,28 @@ ephemeral_node_inactivity_timeout: 30m
 # In case of doubts, do not touch the default 10s.
 node_update_check_interval: 10s

-# SQLite config
-db_type: sqlite3
+database:
+  type: sqlite

-# For production:
-db_path: /var/lib/headscale/db.sqlite
+  # SQLite config
+  sqlite:
+    path: /var/lib/headscale/db.sqlite

 # # Postgres config
-# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
-# db_type: postgres
-# db_host: localhost
-# db_port: 5432
-# db_name: headscale
-# db_user: foo
-# db_pass: bar
+# postgres:
+#   # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
+#   host: localhost
+#   port: 5432
+#   name: headscale
+#   user: foo
+#   pass: bar
+#   max_open_conns: 10
+#   max_idle_conns: 10
+#   conn_max_idle_time_secs: 3600

-# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
-# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
-# db_ssl: false
+#   # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
+#   # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
+#   ssl: false

 ### TLS configuration
 #
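Assembled from the hunk above, the new-style `database:` block looks roughly like this when running on PostgreSQL; this is only a sketch with placeholder host, credentials and pool sizes, and the commented block in config-example.yaml remains the authoritative reference:

```yaml
database:
  type: postgres
  postgres:
    host: localhost
    port: 5432
    name: headscale
    user: foo
    pass: bar
    # New connection-pool settings introduced with this restructure
    max_open_conns: 10
    max_idle_conns: 10
    conn_max_idle_time_secs: 3600
    ssl: false
```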
@@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor:
 Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"):

 ```
+New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
 New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
 New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
 ```
hscontrol/app.go (109 changes)
@@ -12,7 +12,6 @@ import (
     "os"
     "os/signal"
     "runtime"
-    "strconv"
     "strings"
     "sync"
     "syscall"
@@ -33,6 +32,7 @@ import (
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/patrickmn/go-cache"
     zerolog "github.com/philip-bui/grpc-zerolog"
+    "github.com/pkg/profile"
     "github.com/prometheus/client_golang/prometheus/promhttp"
     zl "github.com/rs/zerolog"
     "github.com/rs/zerolog/log"
@@ -48,6 +48,7 @@ import (
     "google.golang.org/grpc/peer"
     "google.golang.org/grpc/reflection"
     "google.golang.org/grpc/status"
+    "gorm.io/gorm"
     "tailscale.com/envknob"
     "tailscale.com/tailcfg"
     "tailscale.com/types/dnstype"
@@ -61,7 +62,7 @@ var (
         "unknown value for Lets Encrypt challenge type",
     )
     errEmptyInitialDERPMap = errors.New(
-        "initial DERPMap is empty, Headscale requries at least one entry",
+        "initial DERPMap is empty, Headscale requires at least one entry",
     )
 )

@@ -116,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
         return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
     }

-    var dbString string
-    switch cfg.DBtype {
-    case db.Postgres:
-        dbString = fmt.Sprintf(
-            "host=%s dbname=%s user=%s",
-            cfg.DBhost,
-            cfg.DBname,
-            cfg.DBuser,
-        )
-
-        if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
-            if !sslEnabled {
-                dbString += " sslmode=disable"
-            }
-        } else {
-            dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
-        }
-
-        if cfg.DBport != 0 {
-            dbString += fmt.Sprintf(" port=%d", cfg.DBport)
-        }
-
-        if cfg.DBpass != "" {
-            dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
-        }
-    case db.Sqlite:
-        dbString = cfg.DBpath
-    default:
-        return nil, errUnsupportedDatabase
-    }
-
     registrationCache := cache.New(
         registerCacheExpiration,
         registerCacheCleanup,
@@ -154,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {

     app := Headscale{
         cfg:                cfg,
-        dbType:             cfg.DBtype,
-        dbString:           dbString,
         noisePrivateKey:    noisePrivateKey,
         registrationCache:  registrationCache,
         pollNetMapStreamWG: sync.WaitGroup{},
@@ -163,9 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
     }

     database, err := db.NewHeadscaleDatabase(
-        cfg.DBtype,
-        dbString,
-        app.dbDebug,
+        cfg.Database,
         app.nodeNotifier,
         cfg.IPPrefixes,
         cfg.BaseDomain)
@@ -234,8 +200,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
 // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
 func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
     ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
+
+    var update types.StateUpdate
+    var changed bool
     for range ticker.C {
-        h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout)
+        if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+            update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
+
+            return nil
+        }); err != nil {
+            log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
+            continue
+        }
+
+        if changed && update.Valid() {
+            ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
+            h.nodeNotifier.NotifyAll(ctx, update)
+        }
     }
 }
@@ -246,9 +227,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
     ticker := time.NewTicker(interval)

     lastCheck := time.Unix(0, 0)
+    var update types.StateUpdate
+    var changed bool
+
     for range ticker.C {
-        lastCheck = h.db.ExpireExpiredNodes(lastCheck)
+        if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+            lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
+
+            return nil
+        }); err != nil {
+            log.Error().Err(err).Msg("database error while expiring nodes")
+            continue
+        }
+
+        log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
+        if changed && update.Valid() {
+            ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
+            h.nodeNotifier.NotifyAll(ctx, update)
+        }
     }
 }
@@ -268,7 +264,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
         case <-ticker.C:
             log.Info().Msg("Fetching DERPMap updates")
             h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
-            if h.cfg.DERP.ServerEnabled {
+            if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
                 region, _ := h.DERPServer.GenerateRegion()
                 h.DERPMap.Regions[region.RegionID] = &region
             }
@@ -278,7 +274,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
                 DERPMap: h.DERPMap,
             }
             if stateUpdate.Valid() {
-                h.nodeNotifier.NotifyAll(stateUpdate)
+                ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
+                h.nodeNotifier.NotifyAll(ctx, stateUpdate)
             }
         }
     }
@@ -485,6 +482,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

 // Serve launches a GIN server with the Headscale API.
 func (h *Headscale) Serve() error {
+    if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
+        if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
+            err := os.MkdirAll(profilePath, os.ModePerm)
+            if err != nil {
+                log.Fatal().Err(err).Msg("failed to create profiling directory")
+            }
+
+            defer profile.Start(profile.ProfilePath(profilePath)).Stop()
+        } else {
+            defer profile.Start().Stop()
+        }
+    }
+
     var err error

     // Fetch an initial DERP Map before we start serving
@@ -501,7 +511,9 @@ func (h *Headscale) Serve() error {
             return err
         }

-        h.DERPMap.Regions[region.RegionID] = &region
+        if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
+            h.DERPMap.Regions[region.RegionID] = &region
+        }

         go h.DERPServer.ServeSTUN()
     }
@@ -708,14 +720,16 @@ func (h *Headscale) Serve() error {

     var tailsqlContext context.Context
     if tailsqlEnabled {
-        if h.cfg.DBtype != db.Sqlite {
-            log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite)
+        if h.cfg.Database.Type != types.DatabaseSqlite {
+            log.Fatal().
+                Str("type", h.cfg.Database.Type).
+                Msgf("tailsql only support %q", types.DatabaseSqlite)
         }
         if tailsqlTSKey == "" {
             log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
         }
         tailsqlContext = context.Background()
-        go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath)
+        go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
     }

     // Handle common process-killing signals so we can gracefully shut down:
@@ -751,7 +765,8 @@ func (h *Headscale) Serve() error {
                     Str("path", aclPath).
                     Msg("ACL policy successfully reloaded, notifying nodes of change")

-                h.nodeNotifier.NotifyAll(types.StateUpdate{
+                ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
+                h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
                     Type: types.StateFullUpdate,
                 })
             }
@@ -1,6 +1,7 @@
 package hscontrol

 import (
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -8,6 +9,7 @@ import (
     "strings"
     "time"

+    "github.com/juanfont/headscale/hscontrol/db"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/rs/zerolog/log"
@@ -199,6 +201,19 @@ func (h *Headscale) handleRegister(
             return
         }

+        // When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed
+        if node.NodeKey.String() != registerRequest.NodeKey.String() &&
+            registerRequest.OldNodeKey.IsZero() && !node.IsExpired() {
+            h.handleNodeKeyRefresh(
+                writer,
+                registerRequest,
+                *node,
+                machineKey,
+            )
+
+            return
+        }
+
         if registerRequest.Followup != "" {
             select {
             case <-req.Context().Done():
@@ -230,8 +245,6 @@ func (h *Headscale) handleRegister(

 // handleAuthKey contains the logic to manage auth key client registration
 // When using Noise, the machineKey is Zero.
-//
-// TODO: check if any locks are needed around IP allocation.
 func (h *Headscale) handleAuthKey(
     writer http.ResponseWriter,
     registerRequest tailcfg.RegisterRequest,
@@ -298,6 +311,9 @@ func (h *Headscale) handleAuthKey(

     nodeKey := registerRequest.NodeKey

+    var update types.StateUpdate
+    var mkey key.MachinePublic
+
     // retrieve node information if it exist
     // The error is not important, because if it does not
     // exist, then this is a new node and we will move
@@ -311,7 +327,7 @@ func (h *Headscale) handleAuthKey(

         node.NodeKey = nodeKey
         node.AuthKeyID = uint(pak.ID)
-        err := h.db.NodeSetExpiry(node, registerRequest.Expiry)
+        err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
         if err != nil {
             log.Error().
                 Caller().
@@ -322,10 +338,13 @@ func (h *Headscale) handleAuthKey(
             return
         }

+        mkey = node.MachineKey
+        update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
+
         aclTags := pak.Proto().GetAclTags()
         if len(aclTags) > 0 {
             // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
-            err = h.db.SetTags(node, aclTags)
+            err = h.db.SetTags(node.ID, aclTags)
             if err != nil {
                 log.Error().
@@ -357,6 +376,7 @@ func (h *Headscale) handleAuthKey(
             Hostname:       registerRequest.Hostinfo.Hostname,
             GivenName:      givenName,
             UserID:         pak.User.ID,
+            User:           pak.User,
             MachineKey:     machineKey,
             RegisterMethod: util.RegisterMethodAuthKey,
             Expiry:         &registerRequest.Expiry,
@@ -380,9 +400,18 @@ func (h *Headscale) handleAuthKey(

             return
         }
+
+        mkey = node.MachineKey
+        update = types.StateUpdate{
+            Type:        types.StatePeerChanged,
+            ChangeNodes: types.Nodes{node},
+            Message:     "called from auth.handleAuthKey",
+        }
     }

-    err = h.db.UsePreAuthKey(pak)
+    err = h.db.DB.Transaction(func(tx *gorm.DB) error {
+        return db.UsePreAuthKey(tx, pak)
+    })
     if err != nil {
         log.Error().
             Caller().
@@ -424,6 +453,13 @@ func (h *Headscale) handleAuthKey(
             Caller().
             Err(err).
             Msg("Failed to write response")
+
+        return
+    }
+
+    // TODO(kradalby): if notifying after register make sense.
+    if update.Valid() {
+        ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
+        h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
     }

     log.Info().
@@ -489,7 +525,7 @@ func (h *Headscale) handleNodeLogOut(
         Msg("Client requested logout")

     now := time.Now()
-    err := h.db.NodeSetExpiry(&node, now)
+    err := h.db.NodeSetExpiry(node.ID, now)
     if err != nil {
         log.Error().
             Caller().
@@ -500,17 +536,10 @@ func (h *Headscale) handleNodeLogOut(
         return
     }

-    stateUpdate := types.StateUpdate{
-        Type: types.StatePeerChangedPatch,
-        ChangePatches: []*tailcfg.PeerChange{
-            {
-                NodeID:    tailcfg.NodeID(node.ID),
-                KeyExpiry: &now,
-            },
-        },
-    }
+    stateUpdate := types.StateUpdateExpire(node.ID, now)
     if stateUpdate.Valid() {
-        h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
+        ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
+        h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
     }

     resp.AuthURL = ""
@@ -541,7 +570,7 @@ func (h *Headscale) handleNodeLogOut(
     }

     if node.IsEphemeral() {
-        err = h.db.DeleteNode(&node)
+        err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
         if err != nil {
             log.Error().
                 Err(err).
@@ -549,6 +578,15 @@ func (h *Headscale) handleNodeLogOut(
                 Msg("Cannot delete ephemeral node from the database")
         }

+        stateUpdate := types.StateUpdate{
+            Type:    types.StatePeerRemoved,
+            Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
+        }
+        if stateUpdate.Valid() {
+            ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
+            h.nodeNotifier.NotifyAll(ctx, stateUpdate)
+        }
+
         return
     }

@@ -620,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh(
         Str("node", node.Hostname).
         Msg("We have the OldNodeKey in the database. This is a key refresh")

-    err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey)
+    err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+        return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
+    })
     if err != nil {
         log.Error().
             Caller().
@@ -13,16 +13,23 @@ import (
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "go4.org/netipx"
+    "gorm.io/gorm"
 )

 var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")

 func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
+    return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
+        return getAvailableIPs(rx, hsdb.ipPrefixes)
+    })
+}
+
+func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
     var ips types.NodeAddresses
     var err error
-    for _, ipPrefix := range hsdb.ipPrefixes {
+    for _, ipPrefix := range ipPrefixes {
         var ip *netip.Addr
-        ip, err = hsdb.getAvailableIP(ipPrefix)
+        ip, err = getAvailableIP(rx, ipPrefix)
         if err != nil {
             return ips, err
         }
@@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
     return ips, err
 }

-func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
-    usedIps, err := hsdb.getUsedIPs()
+func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
+    usedIps, err := getUsedIPs(rx)
     if err != nil {
         return nil, err
     }
@@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro
     }
 }

-func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
+func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
     // FIXME: This really deserves a better data model,
     // but this was quick to get running and it should be enough
     // to begin experimenting with a dual stack tailnet.
     var addressesSlices []string
-    hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
+    rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)

     var ips netipx.IPSetBuilder
     for _, slice := range addressesSlices {
@@ -7,10 +7,16 @@ import (
     "github.com/juanfont/headscale/hscontrol/util"
     "go4.org/netipx"
     "gopkg.in/check.v1"
+    "gorm.io/gorm"
 )

 func (s *Suite) TestGetAvailableIp(c *check.C) {
-    ips, err := db.getAvailableIPs()
+    tx := db.DB.Begin()
+    defer tx.Rollback()
+
+    ips, err := getAvailableIPs(tx, []netip.Prefix{
+        netip.MustParsePrefix("10.27.0.0/23"),
+    })
+
     c.Assert(err, check.IsNil)

@@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
     pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
     c.Assert(err, check.IsNil)

-    _, err = db.GetNode("test", "testnode")
+    _, err = db.getNode("test", "testnode")
     c.Assert(err, check.NotNil)

     node := types.Node{
@@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
         AuthKeyID:   uint(pak.ID),
         IPAddresses: ips,
     }
-    db.db.Save(&node)
+    db.Write(func(tx *gorm.DB) error {
+        return tx.Save(&node).Error
+    })

-    usedIps, err := db.getUsedIPs()
+    usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
+        return getUsedIPs(rx)
+    })
     c.Assert(err, check.IsNil)

     expected := netip.MustParseAddr("10.27.0.1")
@@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
 }

 func (s *Suite) TestGetMultiIp(c *check.C) {
-    user, err := db.CreateUser("test-ip-multi")
+    user, err := db.CreateUser("test-ip")
     c.Assert(err, check.IsNil)

+    ipPrefixes := []netip.Prefix{
+        netip.MustParsePrefix("10.27.0.0/23"),
+    }
+
     for index := 1; index <= 350; index++ {
-        db.ipAllocationMutex.Lock()
+        tx := db.DB.Begin()

-        ips, err := db.getAvailableIPs()
+        ips, err := getAvailableIPs(tx, ipPrefixes)
         c.Assert(err, check.IsNil)

-        pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
+        pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
         c.Assert(err, check.IsNil)

-        _, err = db.GetNode("test", "testnode")
+        _, err = getNode(tx, "test", "testnode")
         c.Assert(err, check.NotNil)

         node := types.Node{
@@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
             AuthKeyID:   uint(pak.ID),
             IPAddresses: ips,
         }
-        db.db.Save(&node)
-
-        db.ipAllocationMutex.Unlock()
+        tx.Save(&node)
+        c.Assert(tx.Commit().Error, check.IsNil)
     }

-    usedIps, err := db.getUsedIPs()
+    usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
+        return getUsedIPs(rx)
+    })
     c.Assert(err, check.IsNil)

     expected0 := netip.MustParseAddr("10.27.0.1")
@@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
     pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
     c.Assert(err, check.IsNil)

-    _, err = db.GetNode("test", "testnode")
+    _, err = db.getNode("test", "testnode")
     c.Assert(err, check.NotNil)

     node := types.Node{
@@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
         RegisterMethod: util.RegisterMethodAuthKey,
         AuthKeyID:      uint(pak.ID),
     }
-    db.db.Save(&node)
+    db.DB.Save(&node)

     ips2, err := db.getAvailableIPs()
     c.Assert(err, check.IsNil)
@@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
 func (hsdb *HSDatabase) CreateAPIKey(
     expiration *time.Time,
 ) (string, *types.APIKey, error) {
-    hsdb.mu.Lock()
-    defer hsdb.mu.Unlock()
-
     prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
     if err != nil {
         return "", nil, err
@@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
         Expiration: expiration,
     }

-    if err := hsdb.db.Save(&key).Error; err != nil {
+    if err := hsdb.DB.Save(&key).Error; err != nil {
         return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
     }

@@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey(

 // ListAPIKeys returns the list of ApiKeys for a user.
 func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
-    hsdb.mu.RLock()
-    defer hsdb.mu.RUnlock()
-
     keys := []types.APIKey{}
-    if err := hsdb.db.Find(&keys).Error; err != nil {
+    if err := hsdb.DB.Find(&keys).Error; err != nil {
         return nil, err
     }

@@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {

 // GetAPIKey returns a ApiKey for a given key.
 func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
-    hsdb.mu.RLock()
-    defer hsdb.mu.RUnlock()
-
     key := types.APIKey{}
-    if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
+    if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
         return nil, result.Error
     }

@@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {

 // GetAPIKeyByID returns a ApiKey for a given id.
 func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
-    hsdb.mu.RLock()
-    defer hsdb.mu.RUnlock()
-
     key := types.APIKey{}
-    if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
+    if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
         return nil, result.Error
     }

@@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
 // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
 // does not exist.
 func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
-    hsdb.mu.Lock()
-    defer hsdb.mu.Unlock()
-
-    if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
+    if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
         return result.Error
     }

@@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {

 // ExpireAPIKey marks a ApiKey as expired.
 func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
-    hsdb.mu.Lock()
-    defer hsdb.mu.Unlock()
-
-    if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
+    if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
         return err
     }

@@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
 }

 func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
-    hsdb.mu.RLock()
-    defer hsdb.mu.RUnlock()
-
     prefix, hash, found := strings.Cut(keyStr, ".")
     if !found {
         return false, ErrAPIKeyFailedToParse
@ -6,24 +6,20 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
"github.com/go-gormigrate/gormigrate/v2"
|
"github.com/go-gormigrate/gormigrate/v2"
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
Postgres = "postgres"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
Sqlite = "sqlite3"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errDatabaseNotSupported = errors.New("database type not supported")
|
var errDatabaseNotSupported = errors.New("database type not supported")
|
||||||
|
@ -36,12 +32,7 @@ type KV struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
db *gorm.DB
|
DB *gorm.DB
|
||||||
notifier *notifier.Notifier
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
|
|
||||||
ipAllocationMutex sync.Mutex
|
|
||||||
|
|
||||||
ipPrefixes []netip.Prefix
|
ipPrefixes []netip.Prefix
|
||||||
baseDomain string
|
baseDomain string
|
||||||
|
@@ -50,275 +41,290 @@ type HSDatabase struct {
 // TODO(kradalby): assemble this struct from toptions or something typed
 // rather than arguments.
 func NewHeadscaleDatabase(
-	dbType, connectionAddr string,
-	debug bool,
+	cfg types.DatabaseConfig,
 	notifier *notifier.Notifier,
 	ipPrefixes []netip.Prefix,
 	baseDomain string,
 ) (*HSDatabase, error) {
-	dbConn, err := openDB(dbType, connectionAddr, debug)
+	dbConn, err := openDB(cfg)
 	if err != nil {
 		return nil, err
 	}
 
-	migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{
+	migrations := gormigrate.New(
+		dbConn,
+		gormigrate.DefaultOptions,
+		[]*gormigrate.Migration{
 			// New migrations should be added as transactions at the end of this list.
 			// The initial commit here is quite messy, completely out of order and
 			// has no versioning and is the tech debt of not having versioned migrations
 			// prior to this point. This first migration is all DB changes to bring a DB
 			// up to 0.23.0.
 			{
 				ID: "202312101416",
 				Migrate: func(tx *gorm.DB) error {
-					if dbType == Postgres {
+					if cfg.Type == types.DatabasePostgres {
 						tx.Exec(`create extension if not exists "uuid-ossp";`)
 					}
 
 					_ = tx.Migrator().RenameTable("namespaces", "users")
 
 					// the big rename from Machine to Node
 					_ = tx.Migrator().RenameTable("machines", "nodes")
 					_ = tx.Migrator().
 						RenameColumn(&types.Route{}, "machine_id", "node_id")
 
 					err = tx.AutoMigrate(types.User{})
 					if err != nil {
 						return err
 					}
 
 					_ = tx.Migrator().
 						RenameColumn(&types.Node{}, "namespace_id", "user_id")
 					_ = tx.Migrator().
 						RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
 
 					_ = tx.Migrator().
 						RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
 					_ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
 
 					// GivenName is used as the primary source of DNS names, make sure
 					// the field is populated and normalized if it was not when the
 					// node was registered.
 					_ = tx.Migrator().
 						RenameColumn(&types.Node{}, "nickname", "given_name")
 
 					// If the Node table has a column for registered,
 					// find all occourences of "false" and drop them. Then
 					// remove the column.
 					if tx.Migrator().HasColumn(&types.Node{}, "registered") {
 						log.Info().
 							Msg(`Database has legacy "registered" column in node, removing...`)
 
 						nodes := types.Nodes{}
 						if err := tx.Not("registered").Find(&nodes).Error; err != nil {
 							log.Error().Err(err).Msg("Error accessing db")
 						}
 
 						for _, node := range nodes {
 							log.Info().
 								Str("node", node.Hostname).
 								Str("machine_key", node.MachineKey.ShortString()).
 								Msg("Deleting unregistered node")
 							if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil {
 								log.Error().
 									Err(err).
 									Str("node", node.Hostname).
 									Str("machine_key", node.MachineKey.ShortString()).
 									Msg("Error deleting unregistered node")
 							}
 						}
 
 						err := tx.Migrator().DropColumn(&types.Node{}, "registered")
 						if err != nil {
 							log.Error().Err(err).Msg("Error dropping registered column")
 						}
 					}
 
 					err = tx.AutoMigrate(&types.Route{})
 					if err != nil {
 						return err
 					}
 
 					err = tx.AutoMigrate(&types.Node{})
 					if err != nil {
 						return err
 					}
 
 					// Ensure all keys have correct prefixes
 					// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
 					type result struct {
 						ID         uint64
 						MachineKey string
 						NodeKey    string
 						DiscoKey   string
 					}
 					var results []result
 					err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").
 						Find(&results).
 						Error
 					if err != nil {
 						return err
 					}
 
 					for _, node := range results {
 						mKey := node.MachineKey
 						if !strings.HasPrefix(node.MachineKey, "mkey:") {
 							mKey = "mkey:" + node.MachineKey
 						}
 						nKey := node.NodeKey
 						if !strings.HasPrefix(node.NodeKey, "nodekey:") {
 							nKey = "nodekey:" + node.NodeKey
 						}
 
 						dKey := node.DiscoKey
 						if !strings.HasPrefix(node.DiscoKey, "discokey:") {
 							dKey = "discokey:" + node.DiscoKey
 						}
 
 						err := tx.Exec(
 							"UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
 							sql.Named("mKey", mKey),
 							sql.Named("nKey", nKey),
 							sql.Named("dKey", dKey),
 							sql.Named("id", node.ID),
 						).Error
 						if err != nil {
 							return err
 						}
 					}
 
 					if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
 						log.Info().
 							Msgf("Database has legacy enabled_routes column in node, migrating...")
 
 						type NodeAux struct {
 							ID            uint64
 							EnabledRoutes types.IPPrefixes
 						}
 
 						nodesAux := []NodeAux{}
 						err := tx.Table("nodes").
 							Select("id, enabled_routes").
 							Scan(&nodesAux).
 							Error
 						if err != nil {
 							log.Fatal().Err(err).Msg("Error accessing db")
 						}
 						for _, node := range nodesAux {
 							for _, prefix := range node.EnabledRoutes {
 								if err != nil {
 									log.Error().
 										Err(err).
 										Str("enabled_route", prefix.String()).
 										Msg("Error parsing enabled_route")
 
 									continue
 								}
 
 								err = tx.Preload("Node").
 									Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
 									First(&types.Route{}).
 									Error
 								if err == nil {
 									log.Info().
 										Str("enabled_route", prefix.String()).
 										Msg("Route already migrated to new table, skipping")
 
 									continue
 								}
 
 								route := types.Route{
 									NodeID:     node.ID,
 									Advertised: true,
 									Enabled:    true,
 									Prefix:     types.IPPrefix(prefix),
 								}
 								if err := tx.Create(&route).Error; err != nil {
 									log.Error().Err(err).Msg("Error creating route")
 								} else {
 									log.Info().
 										Uint64("node_id", route.NodeID).
 										Str("prefix", prefix.String()).
 										Msg("Route migrated")
 								}
 							}
 						}
 
 						err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes")
 						if err != nil {
 							log.Error().
 								Err(err).
 								Msg("Error dropping enabled_routes column")
 						}
 					}
 
 					if tx.Migrator().HasColumn(&types.Node{}, "given_name") {
 						nodes := types.Nodes{}
 						if err := tx.Find(&nodes).Error; err != nil {
 							log.Error().Err(err).Msg("Error accessing db")
 						}
 
 						for item, node := range nodes {
 							if node.GivenName == "" {
 								normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
 									node.Hostname,
 								)
 								if err != nil {
 									log.Error().
 										Caller().
 										Str("hostname", node.Hostname).
 										Err(err).
 										Msg("Failed to normalize node hostname in DB migration")
 								}
 
 								err = tx.Model(nodes[item]).Updates(types.Node{
 									GivenName: normalizedHostname,
 								}).Error
 								if err != nil {
 									log.Error().
 										Caller().
 										Str("hostname", node.Hostname).
 										Err(err).
 										Msg("Failed to save normalized node name in DB migration")
 								}
 							}
 						}
 					}
 
 					err = tx.AutoMigrate(&KV{})
 					if err != nil {
 						return err
 					}
 
 					err = tx.AutoMigrate(&types.PreAuthKey{})
 					if err != nil {
 						return err
 					}
 
 					err = tx.AutoMigrate(&types.PreAuthKeyACLTag{})
 					if err != nil {
 						return err
 					}
 
 					_ = tx.Migrator().DropTable("shared_machines")
 
 					err = tx.AutoMigrate(&types.APIKey{})
 					if err != nil {
 						return err
 					}
 
 					return nil
 				},
 				Rollback: func(tx *gorm.DB) error {
 					return nil
 				},
 			},
 			{
 				// drop key-value table, it is not used, and has not contained
 				// useful data for a long time or ever.
 				ID: "202312101430",
 				Migrate: func(tx *gorm.DB) error {
 					return tx.Migrator().DropTable("kvs")
 				},
 				Rollback: func(tx *gorm.DB) error {
 					return nil
 				},
 			},
-	})
+		},
+	)
 
 	if err = migrations.Migrate(); err != nil {
 		log.Fatal().Err(err).Msgf("Migration failed: %v", err)
 	}
 
 	db := HSDatabase{
-		db:       dbConn,
-		notifier: notifier,
+		DB: dbConn,
 
 		ipPrefixes: ipPrefixes,
 		baseDomain: baseDomain,
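The comment at the top of the migration list says new migrations should be appended as transactions at the end of the slice. As a rough sketch, a follow-up migration would be added just before the closing of the list roughly like this; the ID and the column involved are invented for illustration and are not part of this change, only the gormigrate entry shape and the timestamp-style ID convention come from the code above:

// Hypothetical example only: a future migration appended to the list above.
{
	ID: "202401010000", // assumed timestamp-style ID, following the existing convention
	Migrate: func(tx *gorm.DB) error {
		// illustrative column on types.Node; not part of this commit
		return tx.Migrator().AddColumn(&types.Node{}, "example_field")
	},
	Rollback: func(tx *gorm.DB) error {
		return tx.Migrator().DropColumn(&types.Node{}, "example_field")
	},
},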
@@ -327,20 +333,19 @@ func NewHeadscaleDatabase(
 	return &db, err
 }
 
-func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
-	log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
+func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
+	// TODO(kradalby): Integrate this with zerolog
 
 	var dbLogger logger.Interface
-	if debug {
+	if cfg.Debug {
 		dbLogger = logger.Default
 	} else {
 		dbLogger = logger.Default.LogMode(logger.Silent)
 	}
 
-	switch dbType {
-	case Sqlite:
+	switch cfg.Type {
+	case types.DatabaseSqlite:
 		db, err := gorm.Open(
-			sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
+			sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
 			&gorm.Config{
 				DisableForeignKeyConstraintWhenMigrating: true,
 				Logger:                                   dbLogger,
@@ -359,16 +364,51 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 
 		return db, err
 
-	case Postgres:
-		return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
+	case types.DatabasePostgres:
+		dbString := fmt.Sprintf(
+			"host=%s dbname=%s user=%s",
+			cfg.Postgres.Host,
+			cfg.Postgres.Name,
+			cfg.Postgres.User,
+		)
+
+		if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
+			if !sslEnabled {
+				dbString += " sslmode=disable"
+			}
+		} else {
+			dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
+		}
+
+		if cfg.Postgres.Port != 0 {
+			dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
+		}
+
+		if cfg.Postgres.Pass != "" {
+			dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
+		}
+
+		db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
 			DisableForeignKeyConstraintWhenMigrating: true,
 			Logger:                                   dbLogger,
 		})
+		if err != nil {
+			return nil, err
+		}
+
+		sqlDB, _ := db.DB()
+		sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections)
+		sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections)
+		sqlDB.SetConnMaxIdleTime(
+			time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second,
+		)
+
+		return db, nil
 	}
 
 	return nil, fmt.Errorf(
 		"database of type %s is not supported: %w",
-		dbType,
+		cfg.Type,
 		errDatabaseNotSupported,
 	)
 }
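To make the Postgres branch above concrete: the DSN it assembles is a space-separated key=value string. This is a standalone sketch of the same string-building logic with made-up values; the field meanings mirror the cfg.Postgres fields used above, but the host, database name, credentials and port here are fabricated examples:

package main

import "fmt"

func main() {
	// Example values only; in headscale these come from cfg.Postgres.
	host, name, user, pass := "localhost", "headscale", "headscale", "secret"
	port := 5432
	ssl := "disable"

	dbString := fmt.Sprintf("host=%s dbname=%s user=%s", host, name, user)
	dbString += fmt.Sprintf(" sslmode=%s", ssl)
	if port != 0 {
		dbString += fmt.Sprintf(" port=%d", port)
	}
	if pass != "" {
		dbString += fmt.Sprintf(" password=%s", pass)
	}

	// Prints: host=localhost dbname=headscale user=headscale sslmode=disable port=5432 password=secret
	fmt.Println(dbString)
}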
@@ -376,7 +416,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
 func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
 	ctx, cancel := context.WithTimeout(ctx, time.Second)
 	defer cancel()
-	sqlDB, err := hsdb.db.DB()
+	sqlDB, err := hsdb.DB.DB()
 	if err != nil {
 		return err
 	}
@@ -385,10 +425,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
 }
 
 func (hsdb *HSDatabase) Close() error {
-	db, err := hsdb.db.DB()
+	db, err := hsdb.DB.DB()
 	if err != nil {
 		return err
 	}
 
 	return db.Close()
 }
+
+func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
+	rx := hsdb.DB.Begin()
+	defer rx.Rollback()
+	return fn(rx)
+}
+
+func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
+	rx := db.Begin()
+	defer rx.Rollback()
+	ret, err := fn(rx)
+	if err != nil {
+		var no T
+		return no, err
+	}
+	return ret, nil
+}
+
+func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
+	tx := hsdb.DB.Begin()
+	defer tx.Rollback()
+	if err := fn(tx); err != nil {
+		return err
+	}
+
+	return tx.Commit().Error
+}
+
+func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
+	tx := db.Begin()
+	defer tx.Rollback()
+	ret, err := fn(tx)
+	if err != nil {
+		var no T
+		return no, err
+	}
+	return ret, tx.Commit().Error
+}
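The new Read/Write helpers wrap a unit of work in a gorm transaction: the closure runs against the transaction, errors roll it back, and only a nil return (or the generic Write[T] path) commits. A minimal usage sketch, assuming the package is imported as db (its path is hscontrol/db) and that gorm.io/gorm and time are imported; RenameNode and NodeSetExpiry are the tx-level functions introduced later in this diff:

// Sketch only: composing several tx-level helpers in one transaction.
func renameAndExpire(hsdb *db.HSDatabase, nodeID uint64, newName string, expiry time.Time) error {
	return hsdb.Write(func(tx *gorm.DB) error {
		if err := db.RenameNode(tx, nodeID, newName); err != nil {
			return err
		}

		return db.NodeSetExpiry(tx, nodeID, expiry)
	})
}

If either call fails, the deferred Rollback undoes both changes; the commit only happens when the closure returns nil.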
@@ -34,22 +34,21 @@ var (
 	)
 )
 
-// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
 func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.listPeers(node)
+	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
+		return ListPeers(rx, node)
+	})
 }
 
-func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
+// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
+func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
 	log.Trace().
 		Caller().
 		Str("node", node.Hostname).
 		Msg("Finding direct peers")
 
 	nodes := types.Nodes{}
-	if err := hsdb.db.
+	if err := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
 	return nodes, nil
 }
 
-func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.listNodes()
+func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
+		return ListNodes(rx)
+	})
 }
 
-func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
-	nodes := []types.Node{}
-	if err := hsdb.db.
+func ListNodes(tx *gorm.DB) (types.Nodes, error) {
+	nodes := types.Nodes{}
+	if err := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
 	return nodes, nil
 }
 
-func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.listNodesByGivenName(givenName)
-}
-
-func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
+func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
 	nodes := types.Nodes{}
-	if err := hsdb.db.
+	if err := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err
 	return nodes, nil
 }
 
-// GetNode finds a Node by name and user and returns the Node struct.
-func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
+func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
+		return getNode(rx, user, name)
+	})
+}
 
-	nodes, err := hsdb.ListNodesByUser(user)
+// getNode finds a Node by name and user and returns the Node struct.
+func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
+	nodes, err := ListNodesByUser(tx, user)
 	if err != nil {
 		return nil, err
 	}
@@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
 	return nil, ErrNodeNotFound
 }
 
-// GetNodeByGivenName finds a Node by given name and user and returns the Node struct.
-func (hsdb *HSDatabase) GetNodeByGivenName(
-	user string,
-	givenName string,
-) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	node := types.Node{}
-	if err := hsdb.db.
-		Preload("AuthKey").
-		Preload("AuthKey.User").
-		Preload("User").
-		Preload("Routes").
-		Where("given_name = ?", givenName).First(&node).Error; err != nil {
-		return nil, err
-	}
-
-	return nil, ErrNodeNotFound
+func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
+		return GetNodeByID(rx, id)
+	})
 }
 
 // GetNodeByID finds a Node by ID and returns the Node struct.
-func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
+func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
 	mach := types.Node{}
-	if result := hsdb.db.
+	if result := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
 	return &mach, nil
 }
 
-// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
-func (hsdb *HSDatabase) GetNodeByMachineKey(
-	machineKey key.MachinePublic,
-) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getNodeByMachineKey(machineKey)
+func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
+		return GetNodeByMachineKey(rx, machineKey)
+	})
 }
 
-func (hsdb *HSDatabase) getNodeByMachineKey(
+// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
+func GetNodeByMachineKey(
+	tx *gorm.DB,
 	machineKey key.MachinePublic,
 ) (*types.Node, error) {
 	mach := types.Node{}
-	if result := hsdb.db.
+	if result := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey(
 	return &mach, nil
 }
 
-// GetNodeByNodeKey finds a Node by its current NodeKey.
-func (hsdb *HSDatabase) GetNodeByNodeKey(
+func (hsdb *HSDatabase) GetNodeByAnyKey(
+	machineKey key.MachinePublic,
 	nodeKey key.NodePublic,
+	oldNodeKey key.NodePublic,
 ) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	node := types.Node{}
-	if result := hsdb.db.
-		Preload("AuthKey").
-		Preload("AuthKey.User").
-		Preload("User").
-		Preload("Routes").
-		First(&node, "node_key = ?",
-			nodeKey.String()); result.Error != nil {
-		return nil, result.Error
-	}
-
-	return &node, nil
+	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
+		return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
+	})
 }
 
 // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
-func (hsdb *HSDatabase) GetNodeByAnyKey(
+// TODO(kradalby): see if we can remove this.
+func GetNodeByAnyKey(
+	tx *gorm.DB,
 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
 ) (*types.Node, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
 	node := types.Node{}
-	if result := hsdb.db.
+	if result := tx.
 		Preload("AuthKey").
 		Preload("AuthKey.User").
 		Preload("User").
@@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
 	return &node, nil
 }
 
-func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	if result := hsdb.db.Find(node).First(&node); result.Error != nil {
-		return result.Error
-	}
-
-	return nil
+func (hsdb *HSDatabase) SetTags(
+	nodeID uint64,
+	tags []string,
+) error {
+	return hsdb.Write(func(tx *gorm.DB) error {
+		return SetTags(tx, nodeID, tags)
+	})
 }
 
 // SetTags takes a Node struct pointer and update the forced tags.
-func (hsdb *HSDatabase) SetTags(
-	node *types.Node,
+func SetTags(
+	tx *gorm.DB,
+	nodeID uint64,
 	tags []string,
 ) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
 	if len(tags) == 0 {
 		return nil
 	}
 
-	newTags := []string{}
+	newTags := types.StringList{}
 	for _, tag := range tags {
 		if !util.StringOrPrefixListContains(newTags, tag) {
 			newTags = append(newTags, tag)
 		}
 	}
 
-	if err := hsdb.db.Model(node).Updates(types.Node{
-		ForcedTags: newTags,
-	}).Error; err != nil {
+	if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
 		return fmt.Errorf("failed to update tags for node in the database: %w", err)
 	}
 
-	stateUpdate := types.StateUpdate{
-		Type:        types.StatePeerChanged,
-		ChangeNodes: types.Nodes{node},
-		Message:     "called from db.SetTags",
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
-	}
-
 	return nil
 }
 
 // RenameNode takes a Node struct and a new GivenName for the nodes
 // and renames it.
-func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
+func RenameNode(tx *gorm.DB,
+	nodeID uint64, newName string,
+) error {
 	err := util.CheckForFQDNRules(
 		newName,
 	)
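With the notifier removed from SetTags, persisting the tags and telling peers about the change are now two separate steps at the call site. A sketch of that caller pattern, assuming the surrounding code has an *HSDatabase (hsdb), the affected node, the tag list, and a *notifier.Notifier in scope; the NotifyWithIgnore call and the StateUpdate fields are taken from the removed code above, while the Message string is invented:

// Sketch: persist the tags in a transaction, then notify outside of it.
if err := hsdb.SetTags(node.ID, tags); err != nil {
	return err
}

stateUpdate := types.StateUpdate{
	Type:        types.StatePeerChanged,
	ChangeNodes: types.Nodes{node},
	Message:     "tags changed",
}
if stateUpdate.Valid() {
	notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}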
@@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
 		log.Error().
 			Caller().
 			Str("func", "RenameNode").
-			Str("node", node.Hostname).
+			Uint64("nodeID", nodeID).
 			Str("newName", newName).
 			Err(err).
 			Msg("failed to rename node")
 
 		return err
 	}
-	node.GivenName = newName
 
-	if err := hsdb.db.Model(node).Updates(types.Node{
-		GivenName: newName,
-	}).Error; err != nil {
+	if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
 		return fmt.Errorf("failed to rename node in the database: %w", err)
 	}
 
-	stateUpdate := types.StateUpdate{
-		Type:        types.StatePeerChanged,
-		ChangeNodes: types.Nodes{node},
-		Message:     "called from db.RenameNode",
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
-	}
-
 	return nil
 }
 
+func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
+	return hsdb.Write(func(tx *gorm.DB) error {
+		return NodeSetExpiry(tx, nodeID, expiry)
+	})
+}
+
 // NodeSetExpiry takes a Node struct and a new expiry time.
-func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	return hsdb.nodeSetExpiry(node, expiry)
+func NodeSetExpiry(tx *gorm.DB,
+	nodeID uint64, expiry time.Time,
+) error {
+	return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
 }
 
-func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error {
-	if err := hsdb.db.Model(node).Updates(types.Node{
-		Expiry: &expiry,
-	}).Error; err != nil {
-		return fmt.Errorf(
-			"failed to refresh node (update expiration) in the database: %w",
-			err,
-		)
-	}
-
-	node.Expiry = &expiry
-
-	stateSelfUpdate := types.StateUpdate{
-		Type:        types.StateSelfUpdate,
-		ChangeNodes: types.Nodes{node},
-	}
-	if stateSelfUpdate.Valid() {
-		hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
-	}
-
-	stateUpdate := types.StateUpdate{
-		Type: types.StatePeerChangedPatch,
-		ChangePatches: []*tailcfg.PeerChange{
-			{
-				NodeID:    tailcfg.NodeID(node.ID),
-				KeyExpiry: &expiry,
-			},
-		},
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
-	}
-
-	return nil
+func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
+	return hsdb.Write(func(tx *gorm.DB) error {
+		return DeleteNode(tx, node, isConnected)
+	})
 }
 
 // DeleteNode deletes a Node from the database.
-func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	return hsdb.deleteNode(node)
-}
-
-func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
-	err := hsdb.deleteNodeRoutes(node)
+// Caller is responsible for notifying all of change.
+func DeleteNode(tx *gorm.DB,
+	node *types.Node,
+	isConnected map[key.MachinePublic]bool,
+) error {
+	err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
 	if err != nil {
 		return err
 	}
 
 	// Unscoped causes the node to be fully removed from the database.
-	if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil {
+	if err := tx.Unscoped().Delete(&node).Error; err != nil {
 		return err
 	}
 
-	stateUpdate := types.StateUpdate{
-		Type:    types.StatePeerRemoved,
-		Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyAll(stateUpdate)
-	}
-
 	return nil
 }
 
 // UpdateLastSeen sets a node's last seen field indicating that we
 // have recently communicating with this node.
-// This is mostly used to indicate if a node is online and is not
-// extremely important to make sure is fully correct and to avoid
-// holding up the hot path, does not contain any locks and isnt
-// concurrency safe. But that should be ok.
-func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
-	return hsdb.db.Model(node).Updates(types.Node{
-		LastSeen: node.LastSeen,
-	}).Error
+func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
+	return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
 }
 
-func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
+func RegisterNodeFromAuthCallback(
+	tx *gorm.DB,
 	cache *cache.Cache,
 	mkey key.MachinePublic,
 	userName string,
 	nodeExpiry *time.Time,
 	registrationMethod string,
+	ipPrefixes []netip.Prefix,
 ) (*types.Node, error) {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
 	log.Debug().
 		Str("machine_key", mkey.ShortString()).
 		Str("userName", userName).
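DeleteNode's new doc comment makes the caller responsible for fan-out, so the broadcast that the old method performed itself now happens outside the transaction. A sketch of the expected call-site shape, reusing the StatePeerRemoved update the removed code used to send; the notifier variable and the contents of the isConnected map are assumptions of this example:

// Sketch: delete the node in a transaction, then broadcast the removal.
if err := hsdb.DeleteNode(node, isConnected); err != nil {
	return err
}

stateUpdate := types.StateUpdate{
	Type:    types.StatePeerRemoved,
	Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
	notifier.NotifyAll(stateUpdate)
}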
@@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
 
 	if nodeInterface, ok := cache.Get(mkey.String()); ok {
 		if registrationNode, ok := nodeInterface.(types.Node); ok {
-			user, err := hsdb.getUser(userName)
+			user, err := GetUser(tx, userName)
 			if err != nil {
 				return nil, fmt.Errorf(
 					"failed to find user in register node from auth callback, %w",
@@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
 			}
 
 			registrationNode.UserID = user.ID
+			registrationNode.User = *user
 			registrationNode.RegisterMethod = registrationMethod
 
 			if nodeExpiry != nil {
 				registrationNode.Expiry = nodeExpiry
 			}
 
-			node, err := hsdb.registerNode(
+			node, err := RegisterNode(
+				tx,
 				registrationNode,
+				ipPrefixes,
 			)
 
 			if err == nil {
@@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
 	return nil, ErrNodeNotFoundRegistrationCache
 }
 
-// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
 func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	return hsdb.registerNode(node)
+	return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
+		return RegisterNode(tx, node, hsdb.ipPrefixes)
+	})
 }
 
-func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
+// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
+func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
 	log.Debug().
 		Str("node", node.Hostname).
 		Str("machine_key", node.MachineKey.ShortString()).
@@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
 	// so we store the node.Expire and node.Nodekey that has been set when
 	// adding it to the registrationCache
 	if len(node.IPAddresses) > 0 {
-		if err := hsdb.db.Save(&node).Error; err != nil {
+		if err := tx.Save(&node).Error; err != nil {
 			return nil, fmt.Errorf("failed register existing node in the database: %w", err)
 		}
 
@@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
 		return &node, nil
 	}
 
-	hsdb.ipAllocationMutex.Lock()
-	defer hsdb.ipAllocationMutex.Unlock()
-
-	ips, err := hsdb.getAvailableIPs()
+	ips, err := getAvailableIPs(tx, ipPrefixes)
 	if err != nil {
 		log.Error().
 			Caller().
@@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
 
 	node.IPAddresses = ips
 
-	if err := hsdb.db.Save(&node).Error; err != nil {
+	if err := tx.Save(&node).Error; err != nil {
 		return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
 	}
 
@@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
 }
 
 // NodeSetNodeKey sets the node key of a node and saves it to the database.
-func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	if err := hsdb.db.Model(node).Updates(types.Node{
+func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error {
+	return tx.Model(node).Updates(types.Node{
 		NodeKey: nodeKey,
-	}).Error; err != nil {
-		return err
-	}
-
-	return nil
+	}).Error
 }
 
-// NodeSetMachineKey sets the node key of a node and saves it to the database.
 func (hsdb *HSDatabase) NodeSetMachineKey(
 	node *types.Node,
 	machineKey key.MachinePublic,
 ) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
+	return hsdb.Write(func(tx *gorm.DB) error {
+		return NodeSetMachineKey(tx, node, machineKey)
+	})
+}
 
-	if err := hsdb.db.Model(node).Updates(types.Node{
+// NodeSetMachineKey sets the node key of a node and saves it to the database.
+func NodeSetMachineKey(
+	tx *gorm.DB,
+	node *types.Node,
+	machineKey key.MachinePublic,
+) error {
+	return tx.Model(node).Updates(types.Node{
 		MachineKey: machineKey,
-	}).Error; err != nil {
-		return err
-	}
-
-	return nil
+	}).Error
 }
 
 // NodeSave saves a node object to the database, prefer to use a specific save method rather
 // than this. It is intended to be used when we are changing or.
-func (hsdb *HSDatabase) NodeSave(node *types.Node) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	if err := hsdb.db.Save(node).Error; err != nil {
-		return err
-	}
-
-	return nil
+// TODO(kradalby): Remove this func, just use Save.
+func NodeSave(tx *gorm.DB, node *types.Node) error {
+	return tx.Save(node).Error
+}
+
+func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
+		return GetAdvertisedRoutes(rx, node)
+	})
 }
 
 // GetAdvertisedRoutes returns the routes that are be advertised by the given node.
-func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getAdvertisedRoutes(node)
-}
-
-func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
+func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
 	routes := types.Routes{}
 
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e
 	return prefixes, nil
 }
 
-// GetEnabledRoutes returns the routes that are enabled for the node.
 func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getEnabledRoutes(node)
+	return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
+		return GetEnabledRoutes(rx, node)
+	})
 }
 
-func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
+// GetEnabledRoutes returns the routes that are enabled for the node.
+func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
 	routes := types.Routes{}
 
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
 		Find(&routes).Error
@@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro
 	return prefixes, nil
 }
 
-func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
+func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
 	route, err := netip.ParsePrefix(routeStr)
 	if err != nil {
 		return false
 	}
 
-	enabledRoutes, err := hsdb.getEnabledRoutes(node)
+	enabledRoutes, err := GetEnabledRoutes(tx, node)
 	if err != nil {
 		log.Error().Err(err).Msg("Could not get enabled routes")
 
@@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
 	return false
 }
 
+func (hsdb *HSDatabase) enableRoutes(
+	node *types.Node,
+	routeStrs ...string,
+) (*types.StateUpdate, error) {
+	return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return enableRoutes(tx, node, routeStrs...)
+	})
+}
+
 // enableRoutes enables new routes based on a list of new routes.
-func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
+func enableRoutes(tx *gorm.DB,
+	node *types.Node, routeStrs ...string,
+) (*types.StateUpdate, error) {
 	newRoutes := make([]netip.Prefix, len(routeStrs))
 	for index, routeStr := range routeStrs {
 		route, err := netip.ParsePrefix(routeStr)
 		if err != nil {
-			return err
+			return nil, err
 		}
 
 		newRoutes[index] = route
 	}
 
-	advertisedRoutes, err := hsdb.getAdvertisedRoutes(node)
+	advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	for _, newRoute := range newRoutes {
 		if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
-			return fmt.Errorf(
+			return nil, fmt.Errorf(
 				"route (%s) is not available on node %s: %w",
 				node.Hostname,
 				newRoute, ErrNodeRouteIsNotAvailable,
@@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
 	// Separate loop so we don't leave things in a half-updated state
 	for _, prefix := range newRoutes {
 		route := types.Route{}
-		err := hsdb.db.Preload("Node").
+		err := tx.Preload("Node").
 			Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
 			First(&route).Error
 		if err == nil {
@@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
 			// Mark already as primary if there is only this node offering this subnet
 			// (and is not an exit route)
 			if !route.IsExitRoute() {
-				route.IsPrimary = hsdb.isUniquePrefix(route)
+				route.IsPrimary = isUniquePrefix(tx, route)
 			}
 
-			err = hsdb.db.Save(&route).Error
+			err = tx.Save(&route).Error
 			if err != nil {
-				return fmt.Errorf("failed to enable route: %w", err)
+				return nil, fmt.Errorf("failed to enable route: %w", err)
 			}
 		} else {
-			return fmt.Errorf("failed to find route: %w", err)
+			return nil, fmt.Errorf("failed to find route: %w", err)
 		}
 	}
 
 	// Ensure the node has the latest routes when notifying the other
 	// nodes
-	nRoutes, err := hsdb.getNodeRoutes(node)
+	nRoutes, err := GetNodeRoutes(tx, node)
 	if err != nil {
-		return fmt.Errorf("failed to read back routes: %w", err)
+		return nil, fmt.Errorf("failed to read back routes: %w", err)
 	}
 
 	node.Routes = nRoutes
@@ -729,17 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
 		Strs("routes", routeStrs).
 		Msg("enabling routes")
 
-	stateUpdate := types.StateUpdate{
+	return &types.StateUpdate{
 		Type:        types.StatePeerChanged,
 		ChangeNodes: types.Nodes{node},
-		Message:     "called from db.enableRoutes",
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyWithIgnore(
-			stateUpdate, node.MachineKey.String())
-	}
-
-	return nil
+		Message:     "created in db.enableRoutes",
+	}, nil
 }
 
 func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
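enableRoutes now returns the *types.StateUpdate instead of notifying from inside the transaction, so publishing the change moves to the caller. A sketch of that pattern, assuming a notifier is in scope; the route strings are made-up examples and the update shape is the one constructed above:

// Sketch: enable routes in a transaction, then let the caller publish the update.
update, err := hsdb.enableRoutes(node, "10.0.0.0/24", "0.0.0.0/0")
if err != nil {
	return err
}

if update != nil && update.Valid() {
	notifier.NotifyWithIgnore(*update, node.MachineKey.String())
}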
@ -772,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName(
|
||||||
mkey key.MachinePublic,
|
mkey key.MachinePublic,
|
||||||
suppliedName string,
|
suppliedName string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
hsdb.mu.RLock()
|
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
|
||||||
defer hsdb.mu.RUnlock()
|
return GenerateGivenName(rx, mkey, suppliedName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateGivenName(
|
||||||
|
tx *gorm.DB,
|
||||||
|
mkey key.MachinePublic,
|
||||||
|
suppliedName string,
|
||||||
|
) (string, error) {
|
||||||
givenName, err := generateGivenName(suppliedName, false)
|
givenName, err := generateGivenName(suppliedName, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||||
nodes, err := hsdb.listNodesByGivenName(givenName)
|
nodes, err := listNodesByGivenName(tx, givenName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -805,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName(
 	return givenName, nil
 }

-func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) {
+func ExpireEphemeralNodes(tx *gorm.DB,
-	hsdb.mu.Lock()
+	inactivityThreshhold time.Duration,
-	defer hsdb.mu.Unlock()
+) (types.StateUpdate, bool) {
+	users, err := ListUsers(tx)
-	users, err := hsdb.listUsers()
 	if err != nil {
 		log.Error().Err(err).Msg("Error listing users")

-		return
+		return types.StateUpdate{}, false
 	}

+	expired := make([]tailcfg.NodeID, 0)
 	for _, user := range users {
-		nodes, err := hsdb.listNodesByUser(user.Name)
+		nodes, err := ListNodesByUser(tx, user.Name)
 		if err != nil {
 			log.Error().
 				Err(err).
 				Str("user", user.Name).
 				Msg("Error listing nodes in user")

-			return
+			return types.StateUpdate{}, false
 		}

-		expired := make([]tailcfg.NodeID, 0)
 		for idx, node := range nodes {
 			if node.IsEphemeral() && node.LastSeen != nil &&
 				time.Now().
@@ -838,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
 					Str("node", node.Hostname).
 					Msg("Ephemeral client removed from database")

-				err = hsdb.deleteNode(nodes[idx])
+				// empty isConnected map as ephemeral nodes are not routes
+				err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
 				if err != nil {
 					log.Error().
 						Err(err).
@@ -848,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
 			}
 		}

-		if len(expired) > 0 {
+		// TODO(kradalby): needs to be moved out of transaction
-			hsdb.notifier.NotifyAll(types.StateUpdate{
-				Type:    types.StatePeerRemoved,
-				Removed: expired,
-			})
-		}
 	}

+	if len(expired) > 0 {
+		return types.StateUpdate{
+			Type:    types.StatePeerRemoved,
+			Removed: expired,
+		}, true
+	}

+	return types.StateUpdate{}, false
 }

-func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
+func ExpireExpiredNodes(tx *gorm.DB,
-	hsdb.mu.Lock()
+	lastCheck time.Time,
-	defer hsdb.mu.Unlock()
+) (time.Time, types.StateUpdate, bool) {

 	// use the time of the start of the function to ensure we
 	// dont miss some nodes by returning it _after_ we have
 	// checked everything.
 	started := time.Now()

-	expiredNodes := make([]*types.Node, 0)
+	expired := make([]*tailcfg.PeerChange, 0)

-	nodes, err := hsdb.listNodes()
+	nodes, err := ListNodes(tx)
 	if err != nil {
 		log.Error().
 			Err(err).
 			Msg("Error listing nodes to find expired nodes")

-		return time.Unix(0, 0)
+		return time.Unix(0, 0), types.StateUpdate{}, false
 	}
 	for index, node := range nodes {
 		if node.IsExpired() &&
@@ -882,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
 			// It will notify about all nodes that has been expired.
 			// It should only notify about expired nodes since _last check_.
 			node.Expiry.After(lastCheck) {
-			expiredNodes = append(expiredNodes, &nodes[index])
+			expired = append(expired, &tailcfg.PeerChange{
+				NodeID:    tailcfg.NodeID(node.ID),
+				KeyExpiry: node.Expiry,
+			})

+			now := time.Now()
 			// Do not use setNodeExpiry as that has a notifier hook, which
 			// can cause a deadlock, we are updating all changed nodes later
 			// and there is no point in notifiying twice.
-			if err := hsdb.db.Model(nodes[index]).Updates(types.Node{
+			if err := tx.Model(&nodes[index]).Updates(types.Node{
-				Expiry: &started,
+				Expiry: &now,
 			}).Error; err != nil {
 				log.Error().
 					Err(err).
@@ -904,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
 		}
 	}

-	expired := make([]*tailcfg.PeerChange, len(expiredNodes))
+	if len(expired) > 0 {
-	for idx, node := range expiredNodes {
+		return started, types.StateUpdate{
-		expired[idx] = &tailcfg.PeerChange{
+			Type:          types.StatePeerChangedPatch,
-			NodeID:    tailcfg.NodeID(node.ID),
+			ChangePatches: expired,
-			KeyExpiry: &started,
+		}, true
-		}
 	}

-	// Inform the peers of a node with a lightweight update.
+	return started, types.StateUpdate{}, false
-	stateUpdate := types.StateUpdate{
-		Type:          types.StatePeerChangedPatch,
-		ChangePatches: expired,
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyAll(stateUpdate)
-	}
-
-	// Inform the node itself that it has expired.
-	for _, node := range expiredNodes {
-		stateSelfUpdate := types.StateUpdate{
-			Type:        types.StateSelfUpdate,
-			ChangeNodes: types.Nodes{node},
-		}
-		if stateSelfUpdate.Valid() {
-			hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
-		}
-	}
-
-	return started
 }
@@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) {
 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode("test", "testnode")
+	_, err = db.getNode("test", "testnode")
 	c.Assert(err, check.NotNil)

 	nodeKey := key.NewNode()
@@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(node)
+	db.DB.Save(node)

-	_, err = db.GetNode("test", "testnode")
+	_, err = db.getNode("test", "testnode")
 	c.Assert(err, check.IsNil)
 }

@@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

 	_, err = db.GetNodeByID(0)
 	c.Assert(err, check.IsNil)
 }

-func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
-	user, err := db.CreateUser("test")
-	c.Assert(err, check.IsNil)
-
-	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
-	c.Assert(err, check.IsNil)
-
-	_, err = db.GetNodeByID(0)
-	c.Assert(err, check.NotNil)
-
-	nodeKey := key.NewNode()
-	machineKey := key.NewMachine()
-
-	node := types.Node{
-		ID:             0,
-		MachineKey:     machineKey.Public(),
-		NodeKey:        nodeKey.Public(),
-		Hostname:       "testnode",
-		UserID:         user.ID,
-		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      uint(pak.ID),
-	}
-	db.db.Save(&node)
-
-	_, err = db.GetNodeByNodeKey(nodeKey.Public())
-	c.Assert(err, check.IsNil)
-}
-
 func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
 	user, err := db.CreateUser("test")
 	c.Assert(err, check.IsNil)
@@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

 	_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
 	c.Assert(err, check.IsNil)
@@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(1),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

-	err = db.DeleteNode(&node)
+	err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode(user.Name, "testnode3")
+	_, err = db.getNode(user.Name, "testnode3")
 	c.Assert(err, check.NotNil)
 }

@@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) {
 			RegisterMethod: util.RegisterMethodAuthKey,
 			AuthKeyID:      uint(pak.ID),
 		}
-		db.db.Save(&node)
+		db.DB.Save(&node)
 	}

 	node0ByID, err := db.GetNodeByID(0)
@@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
 			RegisterMethod: util.RegisterMethodAuthKey,
 			AuthKeyID:      uint(stor[index%2].key.ID),
 		}
-		db.db.Save(&node)
+		db.DB.Save(&node)
 	}

 	aclPolicy := &policy.ACLPolicy{
@@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode("test", "testnode")
+	_, err = db.getNode("test", "testnode")
 	c.Assert(err, check.NotNil)

 	nodeKey := key.NewNode()
@@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) {
 		AuthKeyID:      uint(pak.ID),
 		Expiry:         &time.Time{},
 	}
-	db.db.Save(node)
+	db.DB.Save(node)

-	nodeFromDB, err := db.GetNode("test", "testnode")
+	nodeFromDB, err := db.getNode("test", "testnode")
 	c.Assert(err, check.IsNil)
 	c.Assert(nodeFromDB, check.NotNil)

 	c.Assert(nodeFromDB.IsExpired(), check.Equals, false)

 	now := time.Now()
-	err = db.NodeSetExpiry(nodeFromDB, now)
+	err = db.NodeSetExpiry(nodeFromDB.ID, now)
+	c.Assert(err, check.IsNil)

+	nodeFromDB, err = db.getNode("test", "testnode")
 	c.Assert(err, check.IsNil)

 	c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
 	pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode("user-1", "testnode")
+	_, err = db.getNode("user-1", "testnode")
 	c.Assert(err, check.NotNil)

 	nodeKey := key.NewNode()
@@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(node)
+	db.DB.Save(node)

 	givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
 	comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
@@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) {
 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode("test", "testnode")
+	_, err = db.getNode("test", "testnode")
 	c.Assert(err, check.NotNil)

 	nodeKey := key.NewNode()
@@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(node)
+	db.DB.Save(node)

 	// assign simple tags
 	sTags := []string{"tag:test", "tag:foo"}
-	err = db.SetTags(node, sTags)
+	err = db.SetTags(node.ID, sTags)
 	c.Assert(err, check.IsNil)
-	node, err = db.GetNode("test", "testnode")
+	node, err = db.getNode("test", "testnode")
 	c.Assert(err, check.IsNil)
 	c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))

 	// assign duplicat tags, expect no errors but no doubles in DB
 	eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
-	err = db.SetTags(node, eTags)
+	err = db.SetTags(node.ID, eTags)
 	c.Assert(err, check.IsNil)
-	node, err = db.GetNode("test", "testnode")
+	node, err = db.getNode("test", "testnode")
 	c.Assert(err, check.IsNil)
 	c.Assert(
 		node.ForcedTags,
@@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
 		IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
 	}

-	db.db.Save(&node)
+	db.DB.Save(&node)

 	sendUpdate, err := db.SaveNodeRoutes(&node)
 	c.Assert(err, check.IsNil)
@@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
 	node0ByID, err := db.GetNodeByID(0)
 	c.Assert(err, check.IsNil)

-	err = db.EnableAutoApprovedRoutes(pol, node0ByID)
+	// TODO(kradalby): Check state update
+	_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
 	c.Assert(err, check.IsNil)

 	enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
@@ -20,7 +20,6 @@ var (
 	ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
 )

-// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
 func (hsdb *HSDatabase) CreatePreAuthKey(
 	userName string,
 	reusable bool,
@@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 	expiration *time.Time,
 	aclTags []string,
 ) (*types.PreAuthKey, error) {
-	// TODO(kradalby): figure out this lock
+	return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
-	// hsdb.mu.Lock()
+		return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
-	// defer hsdb.mu.Unlock()
+	})
+}

-	user, err := hsdb.GetUser(userName)
+// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
+func CreatePreAuthKey(
+	tx *gorm.DB,
+	userName string,
+	reusable bool,
+	ephemeral bool,
+	expiration *time.Time,
+	aclTags []string,
+) (*types.PreAuthKey, error) {
+	user, err := GetUser(tx, userName)
 	if err != nil {
 		return nil, err
 	}
@@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 	}

 	now := time.Now().UTC()
-	kstr, err := hsdb.generateKey()
+	kstr, err := generateKey()
 	if err != nil {
 		return nil, err
 	}
@@ -63,29 +72,25 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 		Expiration:  expiration,
 	}

-	err = hsdb.db.Transaction(func(db *gorm.DB) error {
+	if err := tx.Save(&key).Error; err != nil {
-		if err := db.Save(&key).Error; err != nil {
+		return nil, fmt.Errorf("failed to create key in the database: %w", err)
-			return fmt.Errorf("failed to create key in the database: %w", err)
+	}
-		}

 	if len(aclTags) > 0 {
 		seenTags := map[string]bool{}

 		for _, tag := range aclTags {
 			if !seenTags[tag] {
-				if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
+				if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
-					return fmt.Errorf(
+					return nil, fmt.Errorf(
 						"failed to ceate key tag in the database: %w",
 						err,
 					)
-				}
-				seenTags[tag] = true
 				}
+				seenTags[tag] = true
 			}
 		}
+	}
-		return nil
-	})

 	if err != nil {
 		return nil, err
@@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
 	return &key, nil
 }

-// ListPreAuthKeys returns the list of PreAuthKeys for a user.
 func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
-	hsdb.mu.RLock()
+	return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
-	defer hsdb.mu.RUnlock()
+		return ListPreAuthKeys(rx, userName)
+	})
-	return hsdb.listPreAuthKeys(userName)
 }

-func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
+// ListPreAuthKeys returns the list of PreAuthKeys for a user.
-	user, err := hsdb.getUser(userName)
+func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
+	user, err := GetUser(tx, userName)
 	if err != nil {
 		return nil, err
 	}

 	keys := []types.PreAuthKey{}
-	if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
+	if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
 		return nil, err
 	}

@@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
 }

 // GetPreAuthKey returns a PreAuthKey for a given key.
-func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
+func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
-	hsdb.mu.RLock()
+	pak, err := ValidatePreAuthKey(tx, key)
-	defer hsdb.mu.RUnlock()
-
-	pak, err := hsdb.ValidatePreAuthKey(key)
 	if err != nil {
 		return nil, err
 	}
@@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe

 // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
 // does not exist.
-func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
+func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
-	hsdb.mu.Lock()
+	return tx.Transaction(func(db *gorm.DB) error {
-	defer hsdb.mu.Unlock()
-
-	return hsdb.destroyPreAuthKey(pak)
-}
-
-func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
-	return hsdb.db.Transaction(func(db *gorm.DB) error {
 		if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
 			return result.Error
 		}
@@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
 	})
 }

-// MarkExpirePreAuthKey marks a PreAuthKey as expired.
 func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
-	hsdb.mu.Lock()
+	return hsdb.Write(func(tx *gorm.DB) error {
-	defer hsdb.mu.Unlock()
+		return ExpirePreAuthKey(tx, k)
+	})
+}

-	if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
+// MarkExpirePreAuthKey marks a PreAuthKey as expired.
+func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
+	if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
 		return err
 	}

@@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
 }

 // UsePreAuthKey marks a PreAuthKey as used.
-func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
+func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
 	k.Used = true
-	if err := hsdb.db.Save(k).Error; err != nil {
+	if err := tx.Save(k).Error; err != nil {
 		return fmt.Errorf("failed to update key used status in the database: %w", err)
 	}

 	return nil
 }

+func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
+		return ValidatePreAuthKey(rx, k)
+	})
+}
+
 // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
 // If returns no error and a PreAuthKey, it can be used.
-func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
+func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
 	pak := types.PreAuthKey{}
-	if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
+	if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
 		result.Error,
 		gorm.ErrRecordNotFound,
 	) {
@@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
 	}

 	nodes := types.Nodes{}
-	if err := hsdb.db.
+	if err := tx.
 		Preload("AuthKey").
 		Where(&types.Node{AuthKeyID: uint(pak.ID)}).
 		Find(&nodes).Error; err != nil {
@@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
 	return &pak, nil
 }

-func (hsdb *HSDatabase) generateKey() (string, error) {
+func generateKey() (string, error) {
 	size := 24
 	bytes := make([]byte, size)
 	if _, err := rand.Read(bytes); err != nil {
@@ -6,6 +6,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"gopkg.in/check.v1"
+	"gorm.io/gorm"
 )

 func (*Suite) TestCreatePreAuthKey(c *check.C) {
@@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
 	user, err := db.CreateUser("test2")
 	c.Assert(err, check.IsNil)

-	now := time.Now()
+	now := time.Now().Add(-5 * time.Second)
 	pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
 	c.Assert(err, check.IsNil)

@@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

 	key, err := db.ValidatePreAuthKey(pak.Key)
 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
 		RegisterMethod: util.RegisterMethodAuthKey,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

 	key, err := db.ValidatePreAuthKey(pak.Key)
 	c.Assert(err, check.IsNil)
@@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) {
 		LastSeen:       &now,
 		AuthKeyID:      uint(pak.ID),
 	}
-	db.db.Save(&node)
+	db.DB.Save(&node)

 	_, err = db.ValidatePreAuthKey(pak.Key)
 	// Ephemeral keys are by definition reusable
 	c.Assert(err, check.IsNil)

-	_, err = db.GetNode("test7", "testest")
+	_, err = db.getNode("test7", "testest")
 	c.Assert(err, check.IsNil)

-	db.ExpireEphemeralNodes(time.Second * 20)
+	db.DB.Transaction(func(tx *gorm.DB) error {
+		ExpireEphemeralNodes(tx, time.Second*20)
+		return nil
+	})

 	// The machine record should have been deleted
-	_, err = db.GetNode("test7", "testest")
+	_, err = db.getNode("test7", "testest")
 	c.Assert(err, check.NotNil)
 }

@@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
 	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
 	c.Assert(err, check.IsNil)
 	pak.Used = true
-	db.db.Save(&pak)
+	db.DB.Save(&pak)

 	_, err = db.ValidatePreAuthKey(pak.Key)
 	c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@@ -7,23 +7,15 @@ import (
 	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/rs/zerolog/log"
-	"github.com/samber/lo"
 	"gorm.io/gorm"
 	"tailscale.com/types/key"
 )

 var ErrRouteIsNotAvailable = errors.New("route is not available")

-func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
+func GetRoutes(tx *gorm.DB) (types.Routes, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getRoutes()
-}
-
-func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		Find(&routes).Error
@@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
 	return routes, nil
 }

-func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
+func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		Where("advertised = ? AND enabled = ?", true, true).
@@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
 	return routes, nil
 }

-func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
+func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		Where("prefix = ?", types.IPPrefix(pref)).
@@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
 	return routes, nil
 }

-func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
+func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getNodeAdvertisedRoutes(node)
-}
-
-func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		Where("node_id = ? AND advertised = true", node.ID).
@@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
 }

 func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
-	hsdb.mu.RLock()
+	return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
-	defer hsdb.mu.RUnlock()
+		return GetNodeRoutes(rx, node)
+	})
-	return hsdb.getNodeRoutes(node)
 }

-func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
+func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		Where("node_id = ?", node.ID).
@@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
 	return routes, nil
 }

-func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
+func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
-	return hsdb.getRoute(id)
-}
-
-func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
 	var route types.Route
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Preload("Node.User").
 		First(&route, id).Error
@@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
 	return &route, nil
 }

-func (hsdb *HSDatabase) EnableRoute(id uint64) error {
+func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
-	hsdb.mu.Lock()
+	route, err := GetRoute(tx, id)
-	defer hsdb.mu.Unlock()
-
-	return hsdb.enableRoute(id)
-}
-
-func (hsdb *HSDatabase) enableRoute(id uint64) error {
-	route, err := hsdb.getRoute(id)
 	if err != nil {
-		return err
+		return nil, err
 	}

 	// Tailscale requires both IPv4 and IPv6 exit routes to
 	// be enabled at the same time, as per
 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
 	if route.IsExitRoute() {
-		return hsdb.enableRoutes(
+		return enableRoutes(
+			tx,
 			&route.Node,
 			types.ExitRouteV4.String(),
 			types.ExitRouteV6.String(),
 		)
 	}

-	return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String())
+	return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
 }

-func (hsdb *HSDatabase) DisableRoute(id uint64) error {
+func DisableRoute(tx *gorm.DB,
-	hsdb.mu.Lock()
+	id uint64,
-	defer hsdb.mu.Unlock()
+	isConnected map[key.MachinePublic]bool,
+) (*types.StateUpdate, error) {
-	route, err := hsdb.getRoute(id)
+	route, err := GetRoute(tx, id)
 	if err != nil {
-		return err
+		return nil, err
 	}

 	var routes types.Routes
@@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
 	// Tailscale requires both IPv4 and IPv6 exit routes to
 	// be enabled at the same time, as per
 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
+	var update *types.StateUpdate
 	if !route.IsExitRoute() {
-		err = hsdb.failoverRouteWithNotify(route)
+		update, err = failoverRouteReturnUpdate(tx, isConnected, route)
 		if err != nil {
-			return err
+			return nil, err
 		}

 		route.Enabled = false
 		route.IsPrimary = false
-		err = hsdb.db.Save(route).Error
+		err = tx.Save(route).Error
 		if err != nil {
-			return err
+			return nil, err
 		}
 	} else {
-		routes, err = hsdb.getNodeRoutes(&node)
+		routes, err = GetNodeRoutes(tx, &node)
 		if err != nil {
-			return err
+			return nil, err
 		}

 		for i := range routes {
 			if routes[i].IsExitRoute() {
 				routes[i].Enabled = false
 				routes[i].IsPrimary = false
-				err = hsdb.db.Save(&routes[i]).Error
+				err = tx.Save(&routes[i]).Error
 				if err != nil {
-					return err
+					return nil, err
 				}
 			}
 		}
 	}

 	if routes == nil {
-		routes, err = hsdb.getNodeRoutes(&node)
+		routes, err = GetNodeRoutes(tx, &node)
 		if err != nil {
-			return err
+			return nil, err
 		}
 	}

 	node.Routes = routes

-	stateUpdate := types.StateUpdate{
+	// If update is empty, it means that one was not created
-		Type:        types.StatePeerChanged,
+	// by failover (as a failover was not necessary), create
-		ChangeNodes: types.Nodes{&node},
+	// one and return to the caller.
-		Message:     "called from db.DisableRoute",
+	if update == nil {
-	}
+		update = &types.StateUpdate{
-	if stateUpdate.Valid() {
+			Type: types.StatePeerChanged,
-		hsdb.notifier.NotifyAll(stateUpdate)
+			ChangeNodes: types.Nodes{
+				&node,
+			},
+			Message: "called from db.DisableRoute",
+		}
 	}

-	return nil
+	return update, nil
 }

-func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
+func (hsdb *HSDatabase) DeleteRoute(
-	hsdb.mu.Lock()
+	id uint64,
-	defer hsdb.mu.Unlock()
+	isConnected map[key.MachinePublic]bool,
+) (*types.StateUpdate, error) {
+	return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return DeleteRoute(tx, id, isConnected)
+	})
+}

-	route, err := hsdb.getRoute(id)
+func DeleteRoute(
+	tx *gorm.DB,
+	id uint64,
+	isConnected map[key.MachinePublic]bool,
+) (*types.StateUpdate, error) {
+	route, err := GetRoute(tx, id)
 	if err != nil {
-		return err
+		return nil, err
 	}

 	var routes types.Routes
@@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
 	// Tailscale requires both IPv4 and IPv6 exit routes to
 	// be enabled at the same time, as per
 	// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
+	var update *types.StateUpdate
 	if !route.IsExitRoute() {
-		err := hsdb.failoverRouteWithNotify(route)
+		update, err = failoverRouteReturnUpdate(tx, isConnected, route)
 		if err != nil {
-			return nil
+			return nil, nil
 		}

-		if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
+		if err := tx.Unscoped().Delete(&route).Error; err != nil {
-			return err
+			return nil, err
 		}
 	} else {
-		routes, err := hsdb.getNodeRoutes(&node)
+		routes, err := GetNodeRoutes(tx, &node)
 		if err != nil {
-			return err
+			return nil, err
 		}

 		routesToDelete := types.Routes{}
@@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
 			}
 		}

-		if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
+		if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
-			return err
+			return nil, err
 		}
 	}

+	// If update is empty, it means that one was not created
+	// by failover (as a failover was not necessary), create
+	// one and return to the caller.
 	if routes == nil {
-		routes, err = hsdb.getNodeRoutes(&node)
+		routes, err = GetNodeRoutes(tx, &node)
 		if err != nil {
-			return err
+			return nil, err
 		}
 	}

 	node.Routes = routes

-	stateUpdate := types.StateUpdate{
+	if update == nil {
-		Type:        types.StatePeerChanged,
+		update = &types.StateUpdate{
-		ChangeNodes: types.Nodes{&node},
+			Type: types.StatePeerChanged,
-		Message:     "called from db.DeleteRoute",
+			ChangeNodes: types.Nodes{
-	}
+				&node,
-	if stateUpdate.Valid() {
+			},
-		hsdb.notifier.NotifyAll(stateUpdate)
+			Message: "called from db.DeleteRoute",
+		}
 	}

-	return nil
+	return update, nil
 }

-func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
+func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
-	routes, err := hsdb.getNodeRoutes(node)
+	routes, err := GetNodeRoutes(tx, node)
 	if err != nil {
 		return err
 	}

 	for i := range routes {
-		if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
+		if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
 			return err
 		}

 		// TODO(kradalby): This is a bit too aggressive, we could probably
 		// figure out which routes needs to be failed over rather than all.
-		hsdb.failoverRouteWithNotify(&routes[i])
+		failoverRouteReturnUpdate(tx, isConnected, &routes[i])
 	}

 	return nil
 }

 // isUniquePrefix returns if there is another node providing the same route already.
-func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
+func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
 	var count int64
-	hsdb.db.
+	tx.Model(&types.Route{}).
-		Model(&types.Route{}).
 		Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
 			route.Prefix,
 			route.NodeID,
@@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
 	return count == 0
 }

-func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
+func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
 	var route types.Route
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
 		First(&route).Error
@@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
 	return &route, nil
 }

+func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
+	return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
+		return GetNodePrimaryRoutes(rx, node)
+	})
+}
+
 // getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
 // Exit nodes are not considered for this, as they are never marked as Primary.
-func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
+func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
-	hsdb.mu.RLock()
-	defer hsdb.mu.RUnlock()
-
 	var routes types.Routes
-	err := hsdb.db.
+	err := tx.
 		Preload("Node").
 		Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
 		Find(&routes).Error
@@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
 	return routes, nil
 }

-// SaveNodeRoutes takes a node and updates the database with
-// the new routes.
-// It returns a bool wheter an update should be sent as the
-// saved route impacts nodes.
 func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
-	hsdb.mu.Lock()
+	return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
-	defer hsdb.mu.Unlock()
+		return SaveNodeRoutes(tx, node)
+	})
-	return hsdb.saveNodeRoutes(node)
 }

-func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
+// SaveNodeRoutes takes a node and updates the database with
+// the new routes.
+// It returns a bool whether an update should be sent as the
+// saved route impacts nodes.
+func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
 	sendUpdate := false

 	currentRoutes := types.Routes{}
-	err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
+	err := tx.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
 	if err != nil {
 		return sendUpdate, err
 	}
@@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
 		if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
 			if !route.Advertised {
 				currentRoutes[pos].Advertised = true
-				err := hsdb.db.Save(&currentRoutes[pos]).Error
+				err := tx.Save(&currentRoutes[pos]).Error
 				if err != nil {
 					return sendUpdate, err
 				}
@@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
 		} else if route.Advertised {
 			currentRoutes[pos].Advertised = false
 			currentRoutes[pos].Enabled = false
-			err := hsdb.db.Save(&currentRoutes[pos]).Error
+			err := tx.Save(&currentRoutes[pos]).Error
 			if err != nil {
 				return sendUpdate, err
 			}
@@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
 			Advertised: true,
 			Enabled:    false,
 		}
-		err := hsdb.db.Create(&route).Error
+		err := tx.Create(&route).Error
 		if err != nil {
 			return sendUpdate, err
 		}
|
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
|
||||||
|
|
||||||
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
|
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
|
||||||
// currently have a functioning host that exposes the network.
|
// currently have a functioning host that exposes the network.
|
||||||
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
|
func EnsureFailoverRouteIsAvailable(
|
||||||
nodeRoutes, err := hsdb.getNodeRoutes(node)
|
tx *gorm.DB,
|
||||||
|
isConnected map[key.MachinePublic]bool,
|
||||||
|
node *types.Node,
|
||||||
|
) (*types.StateUpdate, error) {
|
||||||
|
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var changedNodes types.Nodes
|
||||||
for _, nodeRoute := range nodeRoutes {
|
for _, nodeRoute := range nodeRoutes {
|
||||||
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
|
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.IsPrimary {
|
if route.IsPrimary {
|
||||||
// if we have a primary route, and the node is connected
|
// if we have a primary route, and the node is connected
|
||||||
// nothing needs to be done.
|
// nothing needs to be done.
|
||||||
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
|
if isConnected[route.Node.MachineKey] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if not, we need to failover the route
|
// if not, we need to failover the route
|
||||||
err := hsdb.failoverRouteWithNotify(&route)
|
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if update != nil {
|
||||||
|
changedNodes = append(changedNodes, update.ChangeNodes...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if len(changedNodes) != 0 {
|
||||||
}
|
return &types.StateUpdate{
|
||||||
|
|
||||||
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
|
|
||||||
routes, err := hsdb.getNodeRoutes(node)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var changedKeys []key.MachinePublic
|
|
||||||
|
|
||||||
for _, route := range routes {
|
|
||||||
changed, err := hsdb.failoverRoute(&route)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
changedKeys = append(changedKeys, changed...)
|
|
||||||
}
|
|
||||||
|
|
||||||
changedKeys = lo.Uniq(changedKeys)
|
|
||||||
|
|
||||||
var nodes types.Nodes
|
|
||||||
|
|
||||||
for _, key := range changedKeys {
|
|
||||||
node, err := hsdb.GetNodeByMachineKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes = append(nodes, node)
|
|
||||||
}
|
|
||||||
|
|
||||||
if nodes != nil {
|
|
||||||
stateUpdate := types.StateUpdate{
|
|
||||||
Type: types.StatePeerChanged,
|
Type: types.StatePeerChanged,
|
||||||
ChangeNodes: nodes,
|
ChangeNodes: changedNodes,
|
||||||
Message: "called from db.FailoverNodeRoutesWithNotify",
|
Message: "called from db.EnsureFailoverRouteIsAvailable",
|
||||||
}
|
}, nil
|
||||||
if stateUpdate.Valid() {
|
|
||||||
hsdb.notifier.NotifyAll(stateUpdate)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|

-func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
-	changedKeys, err := hsdb.failoverRoute(r)
+func failoverRouteReturnUpdate(
+	tx *gorm.DB,
+	isConnected map[key.MachinePublic]bool,
+	r *types.Route,
+) (*types.StateUpdate, error) {
+	changedKeys, err := failoverRoute(tx, isConnected, r)
 	if err != nil {
-		return err
+		return nil, err
 	}

+	log.Trace().
+		Interface("isConnected", isConnected).
+		Interface("changedKeys", changedKeys).
+		Msg("building route failover")
+
 	if len(changedKeys) == 0 {
-		return nil
+		return nil, nil
 	}

 	var nodes types.Nodes

-	log.Trace().
-		Str("hostname", r.Node.Hostname).
-		Msg("loading machines with new primary routes from db")
-
 	for _, key := range changedKeys {
-		node, err := hsdb.getNodeByMachineKey(key)
+		node, err := GetNodeByMachineKey(tx, key)
 		if err != nil {
-			return err
+			return nil, err
 		}

 		nodes = append(nodes, node)
 	}

-	log.Trace().
-		Str("hostname", r.Node.Hostname).
-		Msg("notifying peers about primary route change")
-
-	if nodes != nil {
-		stateUpdate := types.StateUpdate{
-			Type:        types.StatePeerChanged,
-			ChangeNodes: nodes,
-			Message:     "called from db.failoverRouteWithNotify",
-		}
-		if stateUpdate.Valid() {
-			hsdb.notifier.NotifyAll(stateUpdate)
-		}
-	}
-
-	log.Trace().
-		Str("hostname", r.Node.Hostname).
-		Msg("notified peers about primary route change")
-
-	return nil
+	return &types.StateUpdate{
+		Type:        types.StatePeerChanged,
+		ChangeNodes: nodes,
+		Message:     "called from db.failoverRouteReturnUpdate",
+	}, nil
 }

@@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
 // failoverRoute takes a route that is no longer available,
 //
 // and tries to find a new route to take over its place.
 // If the given route was not primary, it returns early.
-func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
+func failoverRoute(
+	tx *gorm.DB,
+	isConnected map[key.MachinePublic]bool,
+	r *types.Route,
+) ([]key.MachinePublic, error) {
 	if r == nil {
 		return nil, nil
 	}

-	// This route is not a primary route, and it isnt
+	// This route is not a primary route, and it is not
 	// being served to nodes.
 	if !r.IsPrimary {
 		return nil, nil
@@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
 		return nil, nil
 	}

-	routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
+	routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
 	if err != nil {
 		return nil, err
 	}
@@ -585,14 +543,18 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
 			continue
 		}

-		if hsdb.notifier.IsConnected(route.Node.MachineKey) {
+		if !route.Enabled {
+			continue
+		}
+
+		if isConnected[route.Node.MachineKey] {
 			newPrimary = &routes[idx]
 			break
 		}
 	}

 	// If a new route was not found/available,
-	// return with an error.
+	// return without an error.
 	// We do not want to update the database as
 	// the one currently marked as primary is the
 	// best we got.
@@ -606,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro

 	// Remove primary from the old route
 	r.IsPrimary = false
-	err = hsdb.db.Save(&r).Error
+	err = tx.Save(&r).Error
 	if err != nil {
 		log.Error().Err(err).Msg("error disabling new primary route")

@@ -619,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro

 	// Set primary for the new primary
 	newPrimary.IsPrimary = true
-	err = hsdb.db.Save(&newPrimary).Error
+	err = tx.Save(&newPrimary).Error
 	if err != nil {
 		log.Error().Err(err).Msg("error enabling new primary route")

@@ -634,19 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
 	return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
 }
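Stripped of gorm and the headscale types, the selection rule failoverRoute now applies is small enough to state on its own. The following is a self-contained sketch of roughly that rule, using simplified stand-in types rather than the real ones: skip the failing node's own route, skip routes that are not enabled, and promote the first route whose node is currently connected.

package example

type routeInfo struct {
	nodeKey string // stand-in for key.MachinePublic
	enabled bool
}

// pickNewPrimary returns the index of the first candidate eligible to take
// over from the failing primary route, or -1 if none qualifies (in which
// case the current primary stays marked, matching the "best we got" comment).
func pickNewPrimary(failing routeInfo, candidates []routeInfo, isConnected map[string]bool) int {
	for idx, candidate := range candidates {
		if candidate.nodeKey == failing.nodeKey {
			continue // same node as the failing route
		}
		if !candidate.enabled {
			continue // never promote a route the operator has not enabled
		}
		if isConnected[candidate.nodeKey] {
			return idx // first enabled route on a connected node wins
		}
	}

	return -1
}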

-// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
 func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 	aclPolicy *policy.ACLPolicy,
 	node *types.Node,
-) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
+) (*types.StateUpdate, error) {
+	return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return EnableAutoApprovedRoutes(tx, aclPolicy, node)
+	})
+}
+
+// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
+func EnableAutoApprovedRoutes(
+	tx *gorm.DB,
+	aclPolicy *policy.ACLPolicy,
+	node *types.Node,
+) (*types.StateUpdate, error) {
 	if len(node.IPAddresses) == 0 {
-		return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
+		return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
 	}

-	routes, err := hsdb.getNodeAdvertisedRoutes(node)
+	routes, err := GetNodeAdvertisedRoutes(tx, node)
 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 		log.Error().
 			Caller().
@@ -654,9 +623,11 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 			Str("node", node.Hostname).
 			Msg("Could not get advertised routes for node")

-		return err
+		return nil, err
 	}

+	log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
+
 	approvedRoutes := types.Routes{}

 	for _, advertisedRoute := range routes {
@@ -673,9 +644,16 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 				Uint64("nodeId", node.ID).
 				Msg("Failed to resolve autoApprovers for advertised route")

-			return err
+			return nil, err
 		}

+		log.Trace().
+			Str("node", node.Hostname).
+			Str("user", node.User.Name).
+			Strs("routeApprovers", routeApprovers).
+			Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
+			Msg("looking up route for autoapproving")
+
 		for _, approvedAlias := range routeApprovers {
 			if approvedAlias == node.User.Name {
 				approvedRoutes = append(approvedRoutes, advertisedRoute)
@@ -687,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 					Str("alias", approvedAlias).
 					Msg("Failed to expand alias when processing autoApprovers policy")

-					return err
+					return nil, err
 				}

 				// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
@@ -698,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
 		}
 	}

+	update := &types.StateUpdate{
+		Type:        types.StatePeerChanged,
+		ChangeNodes: types.Nodes{},
+		Message:     "created in db.EnableAutoApprovedRoutes",
+	}
+
 	for _, approvedRoute := range approvedRoutes {
-		err := hsdb.enableRoute(uint64(approvedRoute.ID))
+		perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
 		if err != nil {
 			log.Err(err).
 				Str("approvedRoute", approvedRoute.String()).
 				Uint64("nodeId", node.ID).
 				Msg("Failed to enable approved route")

-			return err
+			return nil, err
 		}

+		update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
 	}

-	return nil
+	return update, nil
 }

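Both the wrapper above and the ones later in this diff delegate to generic Write/Read helpers that run a callback against a transaction and return its result. Their real implementation is not part of these hunks; the following is only a plausible sketch of their shape, shown to make the calling convention concrete.

package example

import "gorm.io/gorm"

// Write runs fn in a transaction and commits only if fn succeeds.
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
	tx := db.Begin()

	ret, err := fn(tx)
	if err != nil {
		tx.Rollback()

		return ret, err
	}

	return ret, tx.Commit().Error
}

// Read is the same idea for read-only work; always rolling back keeps the
// database untouched even if fn issued writes by mistake.
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
	rx := db.Begin()
	defer rx.Rollback()

	return fn(rx)
}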
@ -24,7 +24,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNode("test", "test_get_route_node")
|
_, err = db.getNode("test", "test_get_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -42,7 +42,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
Hostinfo: &hostInfo,
|
Hostinfo: &hostInfo,
|
||||||
}
|
}
|
||||||
db.db.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
su, err := db.SaveNodeRoutes(&node)
|
su, err := db.SaveNodeRoutes(&node)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
@ -52,10 +52,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||||
|
|
||||||
err = db.enableRoutes(&node, "192.168.0.0/24")
|
// TODO(kradalby): check state update
|
||||||
|
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +67,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNode("test", "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -91,7 +92,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
Hostinfo: &hostInfo,
|
Hostinfo: &hostInfo,
|
||||||
}
|
}
|
||||||
db.db.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
sendUpdate, err := db.SaveNodeRoutes(&node)
|
sendUpdate, err := db.SaveNodeRoutes(&node)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
@ -106,10 +107,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||||
|
|
||||||
err = db.enableRoutes(&node, "192.168.0.0/24")
|
_, err = db.enableRoutes(&node, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes, err := db.GetEnabledRoutes(&node)
|
enabledRoutes, err := db.GetEnabledRoutes(&node)
|
||||||
|
@ -117,14 +118,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||||
|
|
||||||
// Adding it twice will just let it pass through
|
// Adding it twice will just let it pass through
|
||||||
err = db.enableRoutes(&node, "10.0.0.0/24")
|
_, err = db.enableRoutes(&node, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
|
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||||
|
|
||||||
err = db.enableRoutes(&node, "150.0.10.0/25")
|
_, err = db.enableRoutes(&node, "150.0.10.0/25")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
|
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
|
||||||
|
@ -139,7 +140,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNode("test", "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -163,16 +164,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
Hostinfo: &hostInfo1,
|
Hostinfo: &hostInfo1,
|
||||||
}
|
}
|
||||||
db.db.Save(&node1)
|
db.DB.Save(&node1)
|
||||||
|
|
||||||
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(sendUpdate, check.Equals, false)
|
c.Assert(sendUpdate, check.Equals, false)
|
||||||
|
|
||||||
err = db.enableRoutes(&node1, route.String())
|
_, err = db.enableRoutes(&node1, route.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.enableRoutes(&node1, route2.String())
|
_, err = db.enableRoutes(&node1, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
hostInfo2 := tailcfg.Hostinfo{
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
|
@ -186,13 +187,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
Hostinfo: &hostInfo2,
|
Hostinfo: &hostInfo2,
|
||||||
}
|
}
|
||||||
db.db.Save(&node2)
|
db.DB.Save(&node2)
|
||||||
|
|
||||||
sendUpdate, err = db.SaveNodeRoutes(&node2)
|
sendUpdate, err = db.SaveNodeRoutes(&node2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(sendUpdate, check.Equals, false)
|
c.Assert(sendUpdate, check.Equals, false)
|
||||||
|
|
||||||
err = db.enableRoutes(&node2, route2.String())
|
_, err = db.enableRoutes(&node2, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||||
|
@ -219,7 +220,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNode("test", "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -246,22 +247,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
Hostinfo: &hostInfo1,
|
Hostinfo: &hostInfo1,
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
db.db.Save(&node1)
|
db.DB.Save(&node1)
|
||||||
|
|
||||||
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
sendUpdate, err := db.SaveNodeRoutes(&node1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(sendUpdate, check.Equals, false)
|
c.Assert(sendUpdate, check.Equals, false)
|
||||||
|
|
||||||
err = db.enableRoutes(&node1, prefix.String())
|
_, err = db.enableRoutes(&node1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.enableRoutes(&node1, prefix2.String())
|
_, err = db.enableRoutes(&node1, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err := db.GetNodeRoutes(&node1)
|
routes, err := db.GetNodeRoutes(&node1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.DeleteRoute(uint64(routes[0].ID))
|
// TODO(kradalby): check stateupdate
|
||||||
|
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||||
|
@ -269,17 +271,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||||
|
|
||||||
func TestFailoverRoute(t *testing.T) {
|
func TestFailoverRoute(t *testing.T) {
|
||||||
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
|
||||||
|
|
||||||
// TODO(kradalby): Count/verify updates
|
|
||||||
var sink chan types.StateUpdate
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range sink {
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
machineKeys := []key.MachinePublic{
|
machineKeys := []key.MachinePublic{
|
||||||
key.NewMachine().Public(),
|
key.NewMachine().Public(),
|
||||||
key.NewMachine().Public(),
|
key.NewMachine().Public(),
|
||||||
|
@ -291,6 +285,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
failingRoute types.Route
|
failingRoute types.Route
|
||||||
routes types.Routes
|
routes types.Routes
|
||||||
|
isConnected map[key.MachinePublic]bool
|
||||||
want []key.MachinePublic
|
want []key.MachinePublic
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
@ -371,6 +366,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{
|
routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -382,6 +378,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
types.Route{
|
types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
|
@ -392,8 +389,13 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[1],
|
MachineKey: machineKeys[1],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
isConnected: map[key.MachinePublic]bool{
|
||||||
|
machineKeys[0]: false,
|
||||||
|
machineKeys[1]: true,
|
||||||
|
},
|
||||||
want: []key.MachinePublic{
|
want: []key.MachinePublic{
|
||||||
machineKeys[0],
|
machineKeys[0],
|
||||||
machineKeys[1],
|
machineKeys[1],
|
||||||
|
@ -411,6 +413,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{
|
routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -422,6 +425,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
types.Route{
|
types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
|
@ -432,6 +436,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[1],
|
MachineKey: machineKeys[1],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: nil,
|
want: nil,
|
||||||
|
@ -448,6 +453,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[1],
|
MachineKey: machineKeys[1],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{
|
routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -459,6 +465,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
types.Route{
|
types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
|
@ -469,6 +476,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[1],
|
MachineKey: machineKeys[1],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
types.Route{
|
types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
|
@ -479,8 +487,14 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[2],
|
MachineKey: machineKeys[2],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
isConnected: map[key.MachinePublic]bool{
|
||||||
|
machineKeys[0]: true,
|
||||||
|
machineKeys[1]: true,
|
||||||
|
machineKeys[2]: true,
|
||||||
|
},
|
||||||
want: []key.MachinePublic{
|
want: []key.MachinePublic{
|
||||||
machineKeys[1],
|
machineKeys[1],
|
||||||
machineKeys[0],
|
machineKeys[0],
|
||||||
|
@ -498,6 +512,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{
|
routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -509,6 +524,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
// Offline
|
// Offline
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -520,8 +536,13 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[3],
|
MachineKey: machineKeys[3],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
isConnected: map[key.MachinePublic]bool{
|
||||||
|
machineKeys[0]: true,
|
||||||
|
machineKeys[3]: false,
|
||||||
|
},
|
||||||
want: nil,
|
want: nil,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
@ -536,6 +557,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{
|
routes: types.Routes{
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -547,6 +569,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[0],
|
MachineKey: machineKeys[0],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
// Offline
|
// Offline
|
||||||
types.Route{
|
types.Route{
|
||||||
|
@ -558,6 +581,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[3],
|
MachineKey: machineKeys[3],
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
types.Route{
|
types.Route{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
|
@ -568,14 +592,61 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
MachineKey: machineKeys[1],
|
MachineKey: machineKeys[1],
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
isConnected: map[key.MachinePublic]bool{
|
||||||
|
machineKeys[0]: false,
|
||||||
|
machineKeys[1]: true,
|
||||||
|
machineKeys[3]: false,
|
||||||
|
},
|
||||||
want: []key.MachinePublic{
|
want: []key.MachinePublic{
|
||||||
machineKeys[0],
|
machineKeys[0],
|
||||||
machineKeys[1],
|
machineKeys[1],
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary-none-enabled",
|
||||||
|
failingRoute: types.Route{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: 1,
|
||||||
|
},
|
||||||
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
Node: types.Node{
|
||||||
|
MachineKey: machineKeys[0],
|
||||||
|
},
|
||||||
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
routes: types.Routes{
|
||||||
|
types.Route{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: 1,
|
||||||
|
},
|
||||||
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
Node: types.Node{
|
||||||
|
MachineKey: machineKeys[0],
|
||||||
|
},
|
||||||
|
IsPrimary: true,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
// not enabled
|
||||||
|
types.Route{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: 2,
|
||||||
|
},
|
||||||
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
|
Node: types.Node{
|
||||||
|
MachineKey: machineKeys[1],
|
||||||
|
},
|
||||||
|
IsPrimary: false,
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -583,13 +654,14 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
tmpDir, err := os.MkdirTemp("", "failover-db-test")
|
tmpDir, err := os.MkdirTemp("", "failover-db-test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
notif := notifier.NewNotifier()
|
|
||||||
|
|
||||||
db, err = NewHeadscaleDatabase(
|
db, err = NewHeadscaleDatabase(
|
||||||
"sqlite3",
|
types.DatabaseConfig{
|
||||||
tmpDir+"/headscale_test.db",
|
Type: "sqlite3",
|
||||||
false,
|
Sqlite: types.SqliteConfig{
|
||||||
notif,
|
Path: tmpDir + "/headscale_test.db",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
notifier.NewNotifier(),
|
||||||
[]netip.Prefix{
|
[]netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
},
|
},
|
||||||
|
@ -597,23 +669,15 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Pretend that all the nodes are connected to control
|
|
||||||
for idx, key := range machineKeys {
|
|
||||||
// Pretend one node is offline
|
|
||||||
if idx == 3 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
notif.AddNode(key, sink)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range tt.routes {
|
for _, route := range tt.routes {
|
||||||
if err := db.db.Save(&route).Error; err != nil {
|
if err := db.DB.Save(&route).Error; err != nil {
|
||||||
t.Fatalf("failed to create route: %s", err)
|
t.Fatalf("failed to create route: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := db.failoverRoute(&tt.failingRoute)
|
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
|
||||||
|
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
|
||||||
|
})
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -627,3 +691,231 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
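One practical effect of the new isConnected parameter is visible in this test: connectivity is no longer read from a live notifier but handed in as a literal map. The snippet below is a hypothetical illustration of how such a map is built, with keys generated on the spot; in production the equivalent map comes from the notifier's ConnectedMap, as seen later in this diff.

package example

import "tailscale.com/types/key"

// exampleConnectivity builds the kind of map TestFailoverRoute feeds into
// failoverRoute.
func exampleConnectivity() map[key.MachinePublic]bool {
	online := key.NewMachine().Public()
	offline := key.NewMachine().Public()

	return map[key.MachinePublic]bool{
		online:  true,  // eligible to take over a primary route
		offline: false, // skipped by the failover logic
	}
}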
|
|
||||||
|
// func TestDisableRouteFailover(t *testing.T) {
|
||||||
|
// machineKeys := []key.MachinePublic{
|
||||||
|
// key.NewMachine().Public(),
|
||||||
|
// key.NewMachine().Public(),
|
||||||
|
// key.NewMachine().Public(),
|
||||||
|
// key.NewMachine().Public(),
|
||||||
|
// }
|
||||||
|
|
||||||
|
// tests := []struct {
|
||||||
|
// name string
|
||||||
|
// nodes types.Nodes
|
||||||
|
|
||||||
|
// routeID uint64
|
||||||
|
// isConnected map[key.MachinePublic]bool
|
||||||
|
|
||||||
|
// wantMachineKey key.MachinePublic
|
||||||
|
// wantErr string
|
||||||
|
// }{
|
||||||
|
// {
|
||||||
|
// name: "single-route",
|
||||||
|
// nodes: types.Nodes{
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 0,
|
||||||
|
// MachineKey: machineKeys[0],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 1,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// Node: types.Node{
|
||||||
|
// MachineKey: machineKeys[0],
|
||||||
|
// },
|
||||||
|
// IsPrimary: true,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// routeID: 1,
|
||||||
|
// wantMachineKey: machineKeys[0],
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "failover-simple",
|
||||||
|
// nodes: types.Nodes{
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 0,
|
||||||
|
// MachineKey: machineKeys[0],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 1,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: true,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 1,
|
||||||
|
// MachineKey: machineKeys[1],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 2,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: false,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// routeID: 1,
|
||||||
|
// wantMachineKey: machineKeys[1],
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "no-failover-offline",
|
||||||
|
// nodes: types.Nodes{
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 0,
|
||||||
|
// MachineKey: machineKeys[0],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 1,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: true,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 1,
|
||||||
|
// MachineKey: machineKeys[1],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 2,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: false,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// isConnected: map[key.MachinePublic]bool{
|
||||||
|
// machineKeys[0]: true,
|
||||||
|
// machineKeys[1]: false,
|
||||||
|
// },
|
||||||
|
// routeID: 1,
|
||||||
|
// wantMachineKey: machineKeys[1],
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "failover-to-online",
|
||||||
|
// nodes: types.Nodes{
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 0,
|
||||||
|
// MachineKey: machineKeys[0],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 1,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: true,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// &types.Node{
|
||||||
|
// ID: 1,
|
||||||
|
// MachineKey: machineKeys[1],
|
||||||
|
// Routes: []types.Route{
|
||||||
|
// {
|
||||||
|
// Model: gorm.Model{
|
||||||
|
// ID: 2,
|
||||||
|
// },
|
||||||
|
// Prefix: ipp("10.0.0.0/24"),
|
||||||
|
// IsPrimary: false,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
// RoutableIPs: []netip.Prefix{
|
||||||
|
// netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// isConnected: map[key.MachinePublic]bool{
|
||||||
|
// machineKeys[0]: true,
|
||||||
|
// machineKeys[1]: true,
|
||||||
|
// },
|
||||||
|
// routeID: 1,
|
||||||
|
// wantMachineKey: machineKeys[1],
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, tt := range tests {
|
||||||
|
// t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
|
||||||
|
// assert.NoError(t, err)
|
||||||
|
|
||||||
|
// // bootstrap db
|
||||||
|
// datab.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// for _, node := range tt.nodes {
|
||||||
|
// err := tx.Save(node).Error
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// _, err = SaveNodeRoutes(tx, node)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// })
|
||||||
|
|
||||||
|
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||||
|
// return DisableRoute(tx, tt.routeID, tt.isConnected)
|
||||||
|
// })
|
||||||
|
|
||||||
|
// // if (err.Error() != "") != tt.wantErr {
|
||||||
|
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
// // return
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// if len(got.ChangeNodes) != 1 {
|
||||||
|
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
|
||||||
|
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,9 +46,12 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
|
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
|
||||||
|
|
||||||
db, err = NewHeadscaleDatabase(
|
db, err = NewHeadscaleDatabase(
|
||||||
"sqlite3",
|
types.DatabaseConfig{
|
||||||
tmpDir+"/headscale_test.db",
|
Type: "sqlite3",
|
||||||
false,
|
Sqlite: types.SqliteConfig{
|
||||||
|
Path: tmpDir + "/headscale_test.db",
|
||||||
|
},
|
||||||
|
},
|
||||||
notifier.NewNotifier(),
|
notifier.NewNotifier(),
|
||||||
[]netip.Prefix{
|
[]netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
|
|
|
@ -15,22 +15,25 @@ var (
|
||||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||||
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
|
||||||
|
return CreateUser(tx, name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// CreateUser creates a new User. Returns error if could not be created
|
// CreateUser creates a new User. Returns error if could not be created
|
||||||
// or another user already exists.
|
// or another user already exists.
|
||||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
hsdb.mu.Lock()
|
|
||||||
defer hsdb.mu.Unlock()
|
|
||||||
|
|
||||||
err := util.CheckForFQDNRules(name)
|
err := util.CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
||||||
return nil, ErrUserExists
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
user.Name = name
|
user.Name = name
|
||||||
if err := hsdb.db.Create(&user).Error; err != nil {
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "CreateUser").
|
Str("func", "CreateUser").
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||||
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
return DestroyUser(tx, name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are nodes associated with it.
|
// not exist or if there are nodes associated with it.
|
||||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
hsdb.mu.Lock()
|
user, err := GetUser(tx, name)
|
||||||
defer hsdb.mu.Unlock()
|
|
||||||
|
|
||||||
user, err := hsdb.getUser(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUserNotFound
|
return ErrUserNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := hsdb.listNodesByUser(name)
|
nodes, err := ListNodesByUser(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := hsdb.listPreAuthKeys(name)
|
keys, err := ListPreAuthKeys(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
err = hsdb.destroyPreAuthKey(key)
|
err = DestroyPreAuthKey(tx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
if result := tx.Unscoped().Delete(&user); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
return RenameUser(tx, oldName, newName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// RenameUser renames a User. Returns error if the User does
|
// RenameUser renames a User. Returns error if the User does
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
hsdb.mu.Lock()
|
|
||||||
defer hsdb.mu.Unlock()
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := hsdb.getUser(oldName)
|
oldUser, err := GetUser(tx, oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = hsdb.getUser(newName)
|
_, err = GetUser(tx, newName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return ErrUserExists
|
return ErrUserExists
|
||||||
}
|
}
|
||||||
|
@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
|
|
||||||
oldUser.Name = newName
|
oldUser.Name = newName
|
||||||
|
|
||||||
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
if result := tx.Save(&oldUser); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
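The user functions now come in pairs: an exported method on HSDatabase that opens its own transaction, and a plain function taking *gorm.DB for callers that already hold one. A short sketch of why the second form matters, using only helpers whose signatures appear in this diff (the composing function itself is hypothetical): several steps can be grouped into a single transaction, which the old mutex-guarded methods could not express.

package example

import (
	"github.com/juanfont/headscale/hscontrol/db"
	"gorm.io/gorm"
)

// createAndRenameUser composes two per-transaction helpers atomically, so a
// failure in either step leaves no half-created user behind.
func createAndRenameUser(hsdb *db.HSDatabase, name, finalName string) error {
	return hsdb.Write(func(tx *gorm.DB) error {
		if _, err := db.CreateUser(tx, name); err != nil {
			return err
		}

		return db.RenameUser(tx, name, finalName)
	})
}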
|
|
||||||
// GetUser fetches a user by name.
|
|
||||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||||
hsdb.mu.RLock()
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
defer hsdb.mu.RUnlock()
|
return GetUser(rx, name)
|
||||||
|
})
|
||||||
return hsdb.getUser(name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
|
||||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||||
hsdb.mu.RLock()
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||||
defer hsdb.mu.RUnlock()
|
return ListUsers(rx)
|
||||||
|
})
|
||||||
return hsdb.listUsers()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
// ListUsers gets all the existing users.
|
||||||
|
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
||||||
users := []types.User{}
|
users := []types.User{}
|
||||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
if err := tx.Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNodesByUser gets all the nodes in a given user.
|
// ListNodesByUser gets all the nodes in a given user.
|
||||||
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) {
|
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
||||||
hsdb.mu.RLock()
|
|
||||||
defer hsdb.mu.RUnlock()
|
|
||||||
|
|
||||||
return hsdb.listNodesByUser(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
|
|
||||||
err := util.CheckForFQDNRules(name)
|
err := util.CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user, err := hsdb.getUser(name)
|
user, err := GetUser(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignNodeToUser assigns a Node to a user.
|
|
||||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
||||||
hsdb.mu.Lock()
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
defer hsdb.mu.Unlock()
|
return AssignNodeToUser(tx, node, username)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssignNodeToUser assigns a Node to a user.
|
||||||
|
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
||||||
err := util.CheckForFQDNRules(username)
|
err := util.CheckForFQDNRules(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
user, err := hsdb.getUser(username)
|
user, err := GetUser(tx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
node.User = *user
|
node.User = *user
|
||||||
if result := hsdb.db.Save(&node); result.Error != nil {
|
if result := tx.Save(&node); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
// destroying a user also deletes all associated preauthkeys
|
// destroying a user also deletes all associated preauthkeys
|
||||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
db.db.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
|
@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
db.db.Save(&node)
|
db.DB.Save(&node)
|
||||||
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
err = db.AssignNodeToUser(&node, newUser.Name)
|
||||||
|
|
|
@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||||
RegionID: d.cfg.ServerRegionID,
|
RegionID: d.cfg.ServerRegionID,
|
||||||
HostName: host,
|
HostName: host,
|
||||||
DERPPort: port,
|
DERPPort: port,
|
||||||
|
IPv4: d.cfg.IPv4,
|
||||||
|
IPv6: d.cfg.IPv6,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||||
localDERPregion.Nodes[0].STUNPort = portSTUN
|
localDERPregion.Nodes[0].STUNPort = portSTUN
|
||||||
|
|
||||||
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
|
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
|
||||||
|
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
|
||||||
|
|
||||||
return localDERPregion, nil
|
return localDERPregion, nil
|
||||||
}
|
}
|
||||||
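The new IPv4/IPv6 fields let the embedded DERP region advertise literal addresses alongside the hostname. Below is a small sketch of the resulting structure with placeholder values; the region code, name and addresses are illustrative assumptions, while the field names come from tailscale's tailcfg package.

package example

import "tailscale.com/tailcfg"

// exampleRegion mirrors the shape GenerateRegion now produces: the node
// carries explicit IPv4/IPv6 literals, which lets clients reach the embedded
// DERP server without resolving the hostname first.
func exampleRegion() tailcfg.DERPRegion {
	return tailcfg.DERPRegion{
		RegionID:   999,
		RegionCode: "headscale",
		RegionName: "Headscale Embedded DERP",
		Nodes: []*tailcfg.DERPNode{
			{
				Name:     "999a",
				RegionID: 999,
				HostName: "derp.example.com",
				IPv4:     "198.51.100.1",
				IPv6:     "2001:db8::1",
				DERPPort: 443,
				STUNPort: 3478,
			},
		},
	}
}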
|
@ -208,6 +211,7 @@ func DERPProbeHandler(
|
||||||
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
|
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
|
||||||
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
|
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
|
||||||
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
|
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
|
||||||
|
// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL.
|
||||||
func DERPBootstrapDNSHandler(
|
func DERPBootstrapDNSHandler(
|
||||||
derpMap *tailcfg.DERPMap,
|
derpMap *tailcfg.DERPMap,
|
||||||
) func(http.ResponseWriter, *http.Request) {
|
) func(http.ResponseWriter, *http.Request) {
|
||||||
|
|
|
@ -8,11 +8,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ExpirePreAuthKeyRequest,
|
request *v1.ExpirePreAuthKeyRequest,
|
||||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||||
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
|
-	if err != nil {
-		return nil, err
-	}
-
-	err = api.h.db.ExpirePreAuthKey(preAuthKey)
+	err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
+		preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
+		if err != nil {
+			return err
+		}
+
+		return db.ExpirePreAuthKey(tx, preAuthKey)
+	})
 	if err != nil {
 		return nil, err
 	}
@@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, err
 	}
 
-	node, err := api.h.db.RegisterNodeFromAuthCallback(
-		api.h.registrationCache,
-		mkey,
-		request.GetUser(),
-		nil,
-		util.RegisterMethodCLI,
-	)
+	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+		return db.RegisterNodeFromAuthCallback(
+			tx,
+			api.h.registrationCache,
+			mkey,
+			request.GetUser(),
+			nil,
+			util.RegisterMethodCLI,
+			api.h.cfg.IPPrefixes,
+		)
+	})
 	if err != nil {
 		return nil, err
 	}
 
+	stateUpdate := types.StateUpdate{
+		Type:        types.StatePeerChanged,
+		ChangeNodes: types.Nodes{node},
+		Message:     "called from api.RegisterNode",
+	}
+	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
+		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
+	}
+
 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 }
 
@@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
 	ctx context.Context,
 	request *v1.SetTagsRequest,
 ) (*v1.SetTagsResponse, error) {
-	node, err := api.h.db.GetNodeByID(request.GetNodeId())
-	if err != nil {
-		return nil, err
-	}
-
 	for _, tag := range request.GetTags() {
 		err := validateTag(tag)
 		if err != nil {
-			return &v1.SetTagsResponse{
-				Node: nil,
-			}, status.Error(codes.InvalidArgument, err.Error())
+			return nil, err
 		}
 	}
 
-	err = api.h.db.SetTags(node, request.GetTags())
+	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+		err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
+		if err != nil {
+			return nil, err
+		}
+
+		return db.GetNodeByID(tx, request.GetNodeId())
+	})
 	if err != nil {
 		return &v1.SetTagsResponse{
 			Node: nil,
-		}, status.Error(codes.Internal, err.Error())
+		}, status.Error(codes.InvalidArgument, err.Error())
 	}
 
+	stateUpdate := types.StateUpdate{
+		Type:        types.StatePeerChanged,
+		ChangeNodes: types.Nodes{node},
+		Message:     "called from api.SetTags",
+	}
+	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
+		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
+	}
+
 	log.Trace().
@@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
 
 	err = api.h.db.DeleteNode(
 		node,
+		api.h.nodeNotifier.ConnectedMap(),
 	)
 	if err != nil {
 		return nil, err
 	}
 
+	stateUpdate := types.StateUpdate{
+		Type:    types.StatePeerRemoved,
+		Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
+	}
+	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
+		api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
+	}
+
 	return &v1.DeleteNodeResponse{}, nil
 }
 
@@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
 	ctx context.Context,
 	request *v1.ExpireNodeRequest,
 ) (*v1.ExpireNodeResponse, error) {
-	node, err := api.h.db.GetNodeByID(request.GetNodeId())
+	now := time.Now()
+
+	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+		db.NodeSetExpiry(
+			tx,
+			request.GetNodeId(),
+			now,
+		)
+
+		return db.GetNodeByID(tx, request.GetNodeId())
+	})
 	if err != nil {
 		return nil, err
 	}
 
-	now := time.Now()
+	selfUpdate := types.StateUpdate{
+		Type:        types.StateSelfUpdate,
+		ChangeNodes: types.Nodes{node},
+	}
+	if selfUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
+		api.h.nodeNotifier.NotifyByMachineKey(
+			ctx,
+			selfUpdate,
+			node.MachineKey)
+	}
 
-	api.h.db.NodeSetExpiry(
-		node,
-		now,
-	)
+	stateUpdate := types.StateUpdateExpire(node.ID, now)
+	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
+		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
+	}
 
 	log.Trace().
 		Str("node", node.Hostname).
@@ -306,17 +365,30 @@ func (api headscaleV1APIServer) RenameNode(
 	ctx context.Context,
 	request *v1.RenameNodeRequest,
 ) (*v1.RenameNodeResponse, error) {
-	node, err := api.h.db.GetNodeByID(request.GetNodeId())
+	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+		err := db.RenameNode(
+			tx,
+			request.GetNodeId(),
+			request.GetNewName(),
+		)
+		if err != nil {
+			return nil, err
+		}
+
+		return db.GetNodeByID(tx, request.GetNodeId())
+	})
 	if err != nil {
 		return nil, err
 	}
 
-	err = api.h.db.RenameNode(
-		node,
-		request.GetNewName(),
-	)
-	if err != nil {
-		return nil, err
+	stateUpdate := types.StateUpdate{
+		Type:        types.StatePeerChanged,
+		ChangeNodes: types.Nodes{node},
+		Message:     "called from api.RenameNode",
+	}
+	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
+		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
 	}
 
 	log.Trace().
@@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
 	ctx context.Context,
 	request *v1.ListNodesRequest,
 ) (*v1.ListNodesResponse, error) {
+	isConnected := api.h.nodeNotifier.ConnectedMap()
 	if request.GetUser() != "" {
-		nodes, err := api.h.db.ListNodesByUser(request.GetUser())
+		nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
+			return db.ListNodesByUser(rx, request.GetUser())
+		})
 		if err != nil {
 			return nil, err
 		}
@@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
 
 			// Populate the online field based on
 			// currently connected nodes.
-			resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
+			resp.Online = isConnected[node.MachineKey]
 
 			response[index] = resp
 		}
@@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
 
 		// Populate the online field based on
 		// currently connected nodes.
-		resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
+		resp.Online = isConnected[node.MachineKey]
 
 		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
-			&node,
+			node,
 		)
 		resp.InvalidTags = invalidTags
 		resp.ValidTags = validTags
@@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
 	ctx context.Context,
 	request *v1.GetRoutesRequest,
 ) (*v1.GetRoutesResponse, error) {
-	routes, err := api.h.db.GetRoutes()
+	routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
+		return db.GetRoutes(rx)
+	})
 	if err != nil {
 		return nil, err
 	}
@@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
 	ctx context.Context,
 	request *v1.EnableRouteRequest,
 ) (*v1.EnableRouteResponse, error) {
-	err := api.h.db.EnableRoute(request.GetRouteId())
+	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return db.EnableRoute(tx, request.GetRouteId())
+	})
 	if err != nil {
 		return nil, err
 	}
 
+	if update != nil && update.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
+		api.h.nodeNotifier.NotifyAll(
+			ctx, *update)
+	}
+
 	return &v1.EnableRouteResponse{}, nil
 }
 
@@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
 	ctx context.Context,
 	request *v1.DisableRouteRequest,
 ) (*v1.DisableRouteResponse, error) {
-	err := api.h.db.DisableRoute(request.GetRouteId())
+	isConnected := api.h.nodeNotifier.ConnectedMap()
+	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return db.DisableRoute(tx, request.GetRouteId(), isConnected)
+	})
 	if err != nil {
 		return nil, err
 	}
 
+	if update != nil && update.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
+		api.h.nodeNotifier.NotifyAll(ctx, *update)
+	}
+
 	return &v1.DisableRouteResponse{}, nil
 }
 
@@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
 	ctx context.Context,
 	request *v1.DeleteRouteRequest,
 ) (*v1.DeleteRouteResponse, error) {
-	err := api.h.db.DeleteRoute(request.GetRouteId())
+	isConnected := api.h.nodeNotifier.ConnectedMap()
+	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
+	})
 	if err != nil {
 		return nil, err
 	}
 
+	if update != nil && update.Valid() {
+		ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
+		api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
+	}
+
 	return &v1.DeleteRouteResponse{}, nil
 }

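The handlers above all follow the same pattern: the old methods on the database wrapper become free functions that take a *gorm.DB, and each call site wraps them in db.Read or db.Write. A minimal sketch of what such generic helpers can look like, assuming the signatures implied by the calls above; this is an illustration of the pattern, not the actual headscale implementation:

package db

import "gorm.io/gorm"

// Read runs fn inside a transaction and returns its result.
// Reads are wrapped too, so every query sees a consistent snapshot.
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
	rx := db.Begin()
	defer rx.Rollback()

	ret, err := fn(rx)
	if err != nil {
		var zero T

		return zero, err
	}

	return ret, nil
}

// Write runs fn inside a transaction and commits it if fn succeeds.
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
	tx := db.Begin()
	defer tx.Rollback()

	ret, err := fn(tx)
	if err != nil {
		var zero T

		return zero, err
	}

	return ret, tx.Commit().Error
}
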
@@ -272,13 +272,26 @@ func (m *Mapper) LiteMapResponse(
 	mapRequest tailcfg.MapRequest,
 	node *types.Node,
 	pol *policy.ACLPolicy,
+	messages ...string,
 ) ([]byte, error) {
 	resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
 	if err != nil {
 		return nil, err
 	}
 
-	return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
+	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
+		pol,
+		node,
+		nodeMapToList(m.peers),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	resp.PacketFilter = policy.ReduceFilterRules(node, rules)
+	resp.SSHPolicy = sshPolicy
+
+	return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
 }
 
 func (m *Mapper) KeepAliveResponse(
@@ -380,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse(
 		}
 
 		if patches, ok := m.patches[uint64(change.NodeID)]; ok {
-			patches := append(patches, p)
-
-			m.patches[uint64(change.NodeID)] = patches
+			m.patches[uint64(change.NodeID)] = append(patches, p)
 		} else {
 			m.patches[uint64(change.NodeID)] = []patch{p}
 		}
@@ -458,6 +469,8 @@ func (m *Mapper) marshalMapResponse(
 	switch {
 	case resp.Peers != nil && len(resp.Peers) > 0:
 		responseType = "full"
+	case isSelfUpdate(messages...):
+		responseType = "self"
 	case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
 		responseType = "lite"
 	case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
@@ -656,3 +669,13 @@ func appendPeerChanges(
 
 	return nil
 }
+
+func isSelfUpdate(messages ...string) bool {
+	for _, message := range messages {
+		if strings.Contains(message, types.SelfUpdateIdentifier) {
+			return true
+		}
+	}
+
+	return false
+}

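LiteMapResponse now accepts free-form messages and forwards them to marshalMapResponse, where isSelfUpdate only looks for types.SelfUpdateIdentifier as a substring. A small self-contained sketch of that matching, with the identifier value taken from the types/common.go hunk further down:

package main

import (
	"fmt"
	"strings"
)

// Mirrors types.SelfUpdateIdentifier from the types/common.go hunk below.
const selfUpdateIdentifier = "self-update"

// isSelfUpdate reports whether any message marks this response as a self update.
func isSelfUpdate(messages ...string) bool {
	for _, message := range messages {
		if strings.Contains(message, selfUpdateIdentifier) {
			return true
		}
	}

	return false
}

func main() {
	fmt.Println(isSelfUpdate("called from handlePoll -> update"))       // false
	fmt.Println(isSelfUpdate("self-update requested by node", "other")) // true
}
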
@@ -72,7 +72,7 @@ func tailNode(
 	}
 
 	var derp string
-	if node.Hostinfo.NetInfo != nil {
+	if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
 		derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
 	} else {
 		derp = "127.3.3.40:0" // Zero means disconnected or unknown.

@@ -1,6 +1,7 @@
 package notifier
 
 import (
+	"context"
 	"fmt"
 	"strings"
 	"sync"
@@ -12,26 +13,30 @@ import (
 )
 
 type Notifier struct {
 	l         sync.RWMutex
 	nodes     map[string]chan<- types.StateUpdate
+	connected map[key.MachinePublic]bool
 }
 
 func NewNotifier() *Notifier {
-	return &Notifier{}
+	return &Notifier{
+		nodes:     make(map[string]chan<- types.StateUpdate),
+		connected: make(map[key.MachinePublic]bool),
+	}
 }
 
 func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
 	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
-	defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
+	defer log.Trace().
+		Caller().
+		Str("key", machineKey.ShortString()).
+		Msg("releasing lock to add node")
 
 	n.l.Lock()
 	defer n.l.Unlock()
 
-	if n.nodes == nil {
-		n.nodes = make(map[string]chan<- types.StateUpdate)
-	}
-
 	n.nodes[machineKey.String()] = c
+	n.connected[machineKey] = true
 
 	log.Trace().
 		Str("machine_key", machineKey.ShortString()).
@@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd
 
 func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
 	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
-	defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
+	defer log.Trace().
+		Caller().
+		Str("key", machineKey.ShortString()).
+		Msg("releasing lock to remove node")
 
 	n.l.Lock()
 	defer n.l.Unlock()
 
-	if n.nodes == nil {
+	if len(n.nodes) == 0 {
 		return
 	}
 
 	delete(n.nodes, machineKey.String())
+	n.connected[machineKey] = false
 
 	log.Trace().
 		Str("machine_key", machineKey.ShortString()).
@@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
 	n.l.RLock()
 	defer n.l.RUnlock()
 
-	if _, ok := n.nodes[machineKey.String()]; ok {
-		return true
-	}
-
-	return false
+	return n.connected[machineKey]
 }
 
-func (n *Notifier) NotifyAll(update types.StateUpdate) {
-	n.NotifyWithIgnore(update)
+// TODO(kradalby): This returns a pointer and can be dangerous.
+func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
+	return n.connected
 }
 
-func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
+func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
+	n.NotifyWithIgnore(ctx, update)
+}
+
+func (n *Notifier) NotifyWithIgnore(
+	ctx context.Context,
+	update types.StateUpdate,
+	ignore ...string,
+) {
 	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
 	defer log.Trace().
 		Caller().
 		Interface("type", update.Type).
-		Msg("releasing lock, finished notifing")
+		Msg("releasing lock, finished notifying")
 
 	n.l.RLock()
 	defer n.l.RUnlock()
@@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
 			continue
 		}
 
-		log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update")
-		c <- update
+		select {
+		case <-ctx.Done():
+			log.Error().
+				Err(ctx.Err()).
+				Str("mkey", key).
+				Any("origin", ctx.Value("origin")).
+				Any("hostname", ctx.Value("hostname")).
+				Msgf("update not sent, context cancelled")
+
+			return
+		case c <- update:
+			log.Trace().
+				Str("mkey", key).
+				Any("origin", ctx.Value("origin")).
+				Any("hostname", ctx.Value("hostname")).
+				Msgf("update successfully sent on chan")
+		}
 	}
 }
 
-func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) {
+func (n *Notifier) NotifyByMachineKey(
+	ctx context.Context,
+	update types.StateUpdate,
+	mKey key.MachinePublic,
+) {
 	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
 	defer log.Trace().
 		Caller().
 		Interface("type", update.Type).
-		Msg("releasing lock, finished notifing")
+		Msg("releasing lock, finished notifying")
 
 	n.l.RLock()
 	defer n.l.RUnlock()
 
 	if c, ok := n.nodes[mKey.String()]; ok {
-		c <- update
+		select {
+		case <-ctx.Done():
+			log.Error().
+				Err(ctx.Err()).
+				Str("mkey", mKey.String()).
+				Any("origin", ctx.Value("origin")).
+				Any("hostname", ctx.Value("hostname")).
+				Msgf("update not sent, context cancelled")
+
+			return
+		case c <- update:
+			log.Trace().
+				Str("mkey", mKey.String()).
+				Any("origin", ctx.Value("origin")).
+				Any("hostname", ctx.Value("hostname")).
+				Msgf("update successfully sent on chan")
+		}
 	}
 }

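Both notify paths now perform the channel send inside a select that also watches the caller's context, so a stalled receiver can no longer block the notifier once the three-second NotifyCtx deadline (defined in the types hunk below) expires. A stripped-down sketch of the pattern, independent of the headscale types:

package main

import (
	"context"
	"fmt"
	"time"
)

// send delivers update on c unless ctx is cancelled or times out first.
func send(ctx context.Context, c chan<- string, update string) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case c <- update:
		return nil
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Unbuffered channel with no reader: the send gives up when ctx expires
	// instead of blocking forever.
	c := make(chan string)
	fmt.Println(send(ctx, c, "peer-changed")) // context deadline exceeded
}
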
@@ -21,6 +21,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"golang.org/x/oauth2"
+	"gorm.io/gorm"
 	"tailscale.com/types/key"
 )
 
@@ -569,7 +570,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
 			Str("node", node.Hostname).
 			Msg("node already registered, reauthenticating")
 
-		err := h.db.NodeSetExpiry(node, expiry)
+		err := h.db.NodeSetExpiry(node.ID, expiry)
 		if err != nil {
 			util.LogErr(err, "Failed to refresh node")
 			http.Error(
@@ -623,6 +624,12 @@ func (h *Headscale) validateNodeForOIDCCallback(
 			util.LogErr(err, "Failed to write response")
 		}
 
+		stateUpdate := types.StateUpdateExpire(node.ID, expiry)
+		if stateUpdate.Valid() {
+			ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
+			h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
+		}
+
 		return nil, true, nil
 	}
 
@@ -712,14 +719,22 @@ func (h *Headscale) registerNodeForOIDCCallback(
 	machineKey *key.MachinePublic,
 	expiry time.Time,
 ) error {
-	if _, err := h.db.RegisterNodeFromAuthCallback(
-		// TODO(kradalby): find a better way to use the cache across modules
-		h.registrationCache,
-		*machineKey,
-		user.Name,
-		&expiry,
-		util.RegisterMethodOIDC,
-	); err != nil {
+	if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+		if _, err := db.RegisterNodeFromAuthCallback(
+			// TODO(kradalby): find a better way to use the cache across modules
+			tx,
+			h.registrationCache,
+			*machineKey,
+			user.Name,
+			&expiry,
+			util.RegisterMethodOIDC,
+			h.cfg.IPPrefixes,
+		); err != nil {
+			return err
+		}
+
+		return nil
+	}); err != nil {
 		util.LogErr(err, "could not register node")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)

@@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
 			if node.IPAddresses.InIPSet(expanded) {
 				dests = append(dests, dest)
 			}
+
+			// If the node exposes routes, ensure they are note removed
+			// when the filters are reduced.
+			if node.Hostinfo != nil {
+				// TODO(kradalby): Evaluate if we should only keep
+				// the routes if the route is enabled. This will
+				// require database access in this part of the code.
+				if len(node.Hostinfo.RoutableIPs) > 0 {
+					for _, routableIP := range node.Hostinfo.RoutableIPs {
+						if expanded.ContainsPrefix(routableIP) {
+							dests = append(dests, dest)
+						}
+					}
+				}
+			}
 		}
 
 		if len(dests) > 0 {
@@ -890,32 +905,39 @@ func (pol *ACLPolicy) TagsOfNode(
 	validTags := make([]string, 0)
 	invalidTags := make([]string, 0)
 
+	// TODO(kradalby): Why is this sometimes nil? coming from tailNode?
+	if node == nil {
+		return validTags, invalidTags
+	}
+
 	validTagMap := make(map[string]bool)
 	invalidTagMap := make(map[string]bool)
-	for _, tag := range node.Hostinfo.RequestTags {
-		owners, err := expandOwnersFromTag(pol, tag)
-		if errors.Is(err, ErrInvalidTag) {
-			invalidTagMap[tag] = true
+	if node.Hostinfo != nil {
+		for _, tag := range node.Hostinfo.RequestTags {
+			owners, err := expandOwnersFromTag(pol, tag)
+			if errors.Is(err, ErrInvalidTag) {
+				invalidTagMap[tag] = true
 
-			continue
-		}
-		var found bool
-		for _, owner := range owners {
-			if node.User.Name == owner {
-				found = true
+				continue
+			}
+			var found bool
+			for _, owner := range owners {
+				if node.User.Name == owner {
+					found = true
+				}
+			}
+			if found {
+				validTagMap[tag] = true
+			} else {
+				invalidTagMap[tag] = true
 			}
 		}
-		if found {
-			validTagMap[tag] = true
-		} else {
-			invalidTagMap[tag] = true
+		for tag := range invalidTagMap {
+			invalidTags = append(invalidTags, tag)
+		}
+		for tag := range validTagMap {
+			validTags = append(validTags, tag)
 		}
 	}
-	for tag := range invalidTagMap {
-		invalidTags = append(invalidTags, tag)
-	}
-	for tag := range validTagMap {
-		validTags = append(validTags, tag)
-	}
 
 	return validTags, invalidTags

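The ReduceFilterRules change keeps a destination when it falls inside one of the node's announced RoutableIPs, not only when it matches one of the node's own addresses. A minimal sketch of that containment check, assuming the expanded destination is a *netipx.IPSet from go4.org/netipx (which provides ContainsPrefix); the values are illustrative only:

package main

import (
	"fmt"
	"net/netip"

	"go4.org/netipx"
)

func main() {
	// "expanded" destination from an ACL rule such as 10.33.0.0/16:*.
	var b netipx.IPSetBuilder
	b.AddPrefix(netip.MustParsePrefix("10.33.0.0/16"))
	expanded, _ := b.IPSet()

	// Route announced by a subnet router in its Hostinfo.RoutableIPs.
	routableIP := netip.MustParsePrefix("10.33.0.0/16")

	// True: the destination is kept even though it is not one of the
	// node's own tailnet addresses.
	fmt.Println(expanded.ContainsPrefix(routableIP))
}
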
@@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) {
 			},
 			want: []tailcfg.FilterRule{},
 		},
+		{
+			name: "1604-subnet-routers-are-preserved",
+			pol: ACLPolicy{
+				Groups: Groups{
+					"group:admins": {"user1"},
+				},
+				ACLs: []ACL{
+					{
+						Action:       "accept",
+						Sources:      []string{"group:admins"},
+						Destinations: []string{"group:admins:*"},
+					},
+					{
+						Action:       "accept",
+						Sources:      []string{"group:admins"},
+						Destinations: []string{"10.33.0.0/16:*"},
+					},
+				},
+			},
+			node: &types.Node{
+				IPAddresses: types.NodeAddresses{
+					netip.MustParseAddr("100.64.0.1"),
+					netip.MustParseAddr("fd7a:115c:a1e0::1"),
+				},
+				User: types.User{Name: "user1"},
+				Hostinfo: &tailcfg.Hostinfo{
+					RoutableIPs: []netip.Prefix{
+						netip.MustParsePrefix("10.33.0.0/16"),
+					},
+				},
+			},
+			peers: types.Nodes{
+				&types.Node{
+					IPAddresses: types.NodeAddresses{
+						netip.MustParseAddr("100.64.0.2"),
+						netip.MustParseAddr("fd7a:115c:a1e0::2"),
+					},
+					User: types.User{Name: "user1"},
+				},
+			},
+			want: []tailcfg.FilterRule{
+				{
+					SrcIPs: []string{
+						"100.64.0.1/32",
+						"100.64.0.2/32",
+						"fd7a:115c:a1e0::1/128",
+						"fd7a:115c:a1e0::2/128",
+					},
+					DstPorts: []tailcfg.NetPortRange{
+						{
+							IP:    "100.64.0.1/32",
+							Ports: tailcfg.PortRangeAny,
+						},
+						{
+							IP:    "fd7a:115c:a1e0::1/128",
+							Ports: tailcfg.PortRangeAny,
+						},
+					},
+				},
+				{
+					SrcIPs: []string{
+						"100.64.0.1/32",
+						"100.64.0.2/32",
+						"fd7a:115c:a1e0::1/128",
+						"fd7a:115c:a1e0::2/128",
+					},
+					DstPorts: []tailcfg.NetPortRange{
+						{
+							IP:    "10.33.0.0/16",
+							Ports: tailcfg.PortRangeAny,
+						},
+					},
+				},
+			},
+		},
 	}
 
 	for _, tt := range tests {

@@ -4,12 +4,15 @@ import (
 	"context"
 	"fmt"
 	"net/http"
+	"strings"
 	"time"
 
+	"github.com/juanfont/headscale/hscontrol/db"
 	"github.com/juanfont/headscale/hscontrol/mapper"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/rs/zerolog/log"
 	xslices "golang.org/x/exp/slices"
+	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
 )
 
@@ -125,6 +128,18 @@ func (h *Headscale) handlePoll(
 
 			return
 		}
+
+		if h.ACLPolicy != nil {
+			// update routes with peer information
+			update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
+			if err != nil {
+				logErr(err, "Error running auto approved routes")
+			}
+
+			if update != nil {
+				sendUpdate = true
+			}
+		}
 	}
 
 	// Services is mostly useful for discovery and not critical,
@@ -138,41 +153,68 @@ func (h *Headscale) handlePoll(
 		}
 
 		if sendUpdate {
-			if err := h.db.NodeSave(node); err != nil {
+			if err := h.db.DB.Save(node).Error; err != nil {
 				logErr(err, "Failed to persist/update node in the database")
 				http.Error(writer, "", http.StatusInternalServerError)
 
 				return
 			}
 
+			// Send an update to all peers to propagate the new routes
+			// available.
 			stateUpdate := types.StateUpdate{
 				Type:        types.StatePeerChanged,
 				ChangeNodes: types.Nodes{node},
 				Message:     "called from handlePoll -> update -> new hostinfo",
 			}
 			if stateUpdate.Valid() {
+				ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname)
 				h.nodeNotifier.NotifyWithIgnore(
+					ctx,
 					stateUpdate,
 					node.MachineKey.String())
 			}
 
+			// Send an update to the node itself with to ensure it
+			// has an updated packetfilter allowing the new route
+			// if it is defined in the ACL.
+			selfUpdate := types.StateUpdate{
+				Type:        types.StateSelfUpdate,
+				ChangeNodes: types.Nodes{node},
+			}
+			if selfUpdate.Valid() {
+				ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname)
+				h.nodeNotifier.NotifyByMachineKey(
+					ctx,
+					selfUpdate,
+					node.MachineKey)
+			}
+
 			return
 		}
 	}
 
-	if err := h.db.NodeSave(node); err != nil {
+	if err := h.db.DB.Save(node).Error; err != nil {
 		logErr(err, "Failed to persist/update node in the database")
 		http.Error(writer, "", http.StatusInternalServerError)
 
 		return
 	}
 
+	// TODO(kradalby): Figure out why patch changes does
+	// not show up in output from `tailscale debug netmap`.
+	// stateUpdate := types.StateUpdate{
+	// 	Type:          types.StatePeerChangedPatch,
+	// 	ChangePatches: []*tailcfg.PeerChange{&change},
+	// }
 	stateUpdate := types.StateUpdate{
-		Type:          types.StatePeerChangedPatch,
-		ChangePatches: []*tailcfg.PeerChange{&change},
+		Type:        types.StatePeerChanged,
+		ChangeNodes: types.Nodes{node},
 	}
 	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
 		h.nodeNotifier.NotifyWithIgnore(
+			ctx,
 			stateUpdate,
 			node.MachineKey.String())
 	}
@@ -228,7 +270,7 @@ func (h *Headscale) handlePoll(
 		}
 	}
 
-	if err := h.db.NodeSave(node); err != nil {
+	if err := h.db.DB.Save(node).Error; err != nil {
 		logErr(err, "Failed to persist/update node in the database")
 		http.Error(writer, "", http.StatusInternalServerError)
 
@@ -265,7 +307,10 @@ func (h *Headscale) handlePoll(
 	// update ACLRules with peer informations (to update server tags if necessary)
 	if h.ACLPolicy != nil {
 		// update routes with peer information
-		err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
+		// This state update is ignored as it will be sent
+		// as part of the whole node
+		// TODO(kradalby): figure out if that is actually correct
+		_, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
 		if err != nil {
 			logErr(err, "Error running auto approved routes")
 		}
@@ -301,11 +346,17 @@ func (h *Headscale) handlePoll(
 		Message:     "called from handlePoll -> new node added",
 	}
 	if stateUpdate.Valid() {
+		ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname)
 		h.nodeNotifier.NotifyWithIgnore(
+			ctx,
 			stateUpdate,
 			node.MachineKey.String())
 	}
 
+	if len(node.Routes) > 0 {
+		go h.pollFailoverRoutes(logErr, "new node", node)
+	}
+
 	// Set up the client stream
 	h.pollNetMapStreamWG.Add(1)
 	defer h.pollNetMapStreamWG.Done()
@@ -323,15 +374,9 @@ func (h *Headscale) handlePoll(
 
 	keepAliveTicker := time.NewTicker(keepAliveInterval)
 
-	ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname)
-
-	ctx, cancel := context.WithCancel(ctx)
+	ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))
 	defer cancel()
 
-	if len(node.Routes) > 0 {
-		go h.db.EnsureFailoverRouteIsAvailable(node)
-	}
-
 	for {
 		logInfo("Waiting for update on stream channel")
 		select {
@@ -370,6 +415,17 @@ func (h *Headscale) handlePoll(
 			var data []byte
 			var err error
 
+			// Ensure the node object is updated, for example, there
+			// might have been a hostinfo update in a sidechannel
+			// which contains data needed to generate a map response.
+			node, err = h.db.GetNodeByMachineKey(node.MachineKey)
+			if err != nil {
+				logErr(err, "Could not get machine from db")
+
+				return
+			}
+
+			startMapResp := time.Now()
 			switch update.Type {
 			case types.StateFullUpdate:
 				logInfo("Sending Full MapResponse")
@@ -378,6 +434,7 @@ func (h *Headscale) handlePoll(
 			case types.StatePeerChanged:
 				logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
 
+				isConnectedMap := h.nodeNotifier.ConnectedMap()
 				for _, node := range update.ChangeNodes {
 					// If a node is not reported to be online, it might be
 					// because the value is outdated, check with the notifier.
@@ -385,7 +442,7 @@ func (h *Headscale) handlePoll(
 					// this might be because it has announced itself, but not
 					// reached the stage to actually create the notifier channel.
 					if node.IsOnline != nil && !*node.IsOnline {
-						isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
+						isOnline := isConnectedMap[node.MachineKey]
 						node.IsOnline = &isOnline
 					}
 				}
@@ -401,7 +458,7 @@ func (h *Headscale) handlePoll(
 				if len(update.ChangeNodes) == 1 {
 					logInfo("Sending SelfUpdate MapResponse")
 					node = update.ChangeNodes[0]
-					data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy)
+					data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier)
 				} else {
 					logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.")
 				}
@@ -416,8 +473,11 @@ func (h *Headscale) handlePoll(
 				return
 			}
 
+			log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response")
+
 			// Only send update if there is change
 			if data != nil {
+				startWrite := time.Now()
 				_, err = writer.Write(data)
 				if err != nil {
 					logErr(err, "Could not write the map response")
@@ -435,6 +495,7 @@ func (h *Headscale) handlePoll(
 
 					return
 				}
+				log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node")
 
 				log.Info().
 					Caller().
@@ -454,7 +515,7 @@ func (h *Headscale) handlePoll(
 			go h.updateNodeOnlineStatus(false, node)
 
 			// Failover the node's routes if any.
-			go h.db.FailoverNodeRoutesWithNotify(node)
+			go h.pollFailoverRoutes(logErr, "node closing connection", node)
 
 			// The connection has been closed, so we can stop polling.
 			return
@@ -467,6 +528,22 @@ func (h *Headscale) handlePoll(
 	}
 }
 
+func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) {
+	update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
+		return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node)
+	})
+	if err != nil {
+		logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
+
+		return
+	}
+
+	if update != nil && !update.Empty() && update.Valid() {
+		ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
+		h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String())
+	}
+}
+
 // updateNodeOnlineStatus records the last seen status of a node and notifies peers
 // about change in their online/offline status.
 // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@@ -486,10 +563,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
 		},
 	}
 	if statusUpdate.Valid() {
-		h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
+		ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
+		h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
	}
 
-	err := h.db.UpdateLastSeen(node)
+	err := h.db.DB.Transaction(func(tx *gorm.DB) error {
+		return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
+	})
 	if err != nil {
 		log.Error().Err(err).Msg("Cannot update node LastSeen")
 

@@ -13,7 +13,7 @@ import (
 )
 
 const (
-	MinimumCapVersion tailcfg.CapabilityVersion = 56
+	MinimumCapVersion tailcfg.CapabilityVersion = 58
 )
 
 // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol

@@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
 	}
 	cfg := types.Config{
 		NoisePrivateKeyPath: tmpDir + "/noise_private.key",
-		DBtype:              "sqlite3",
-		DBpath:              tmpDir + "/headscale_test.db",
+		Database: types.DatabaseConfig{
+			Type: "sqlite3",
+			Sqlite: types.SqliteConfig{
+				Path: tmpDir + "/headscale_test.db",
+			},
+		},
 		IPPrefixes: []netip.Prefix{
 			netip.MustParsePrefix("10.27.0.0/23"),
 		},

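The test now builds the nested types.DatabaseConfig instead of the old flat DBtype/DBpath fields; the LoadConfig hunk further down registers viper aliases so the flat keys keep resolving to the new nested ones. A small sketch of how such an alias behaves, runnable outside headscale and purely illustrative:

package main

import (
	"fmt"

	"github.com/spf13/viper"
)

func main() {
	v := viper.New()

	// Same alias direction as registered in LoadConfig: old flat key -> new nested key.
	v.RegisterAlias("db_path", "database.sqlite.path")

	// A legacy config that only sets db_path...
	v.Set("db_path", "/var/lib/headscale/db.sqlite")

	// ...is still visible under the new nested key.
	fmt.Println(v.GetString("database.sqlite.path")) // /var/lib/headscale/db.sqlite
}
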
@@ -1,15 +1,23 @@
 package types
 
 import (
+	"context"
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"net/netip"
+	"time"
 
 	"tailscale.com/tailcfg"
 )
 
+const (
+	SelfUpdateIdentifier = "self-update"
+	DatabasePostgres     = "postgres"
+	DatabaseSqlite       = "sqlite3"
+)
+
 var ErrCannotParsePrefix = errors.New("cannot parse prefix")
 
 type IPPrefix netip.Prefix
@@ -150,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
 		}
 	case StateSelfUpdate:
 		if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
-			panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
+			panic(
+				"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
+			)
 		}
 	case StateDERPUpdated:
 		if su.DERPMap == nil {
@@ -160,3 +170,37 @@ func (su *StateUpdate) Valid() bool {
 
 	return true
 }
+
+// Empty reports if there are any updates in the StateUpdate.
+func (su *StateUpdate) Empty() bool {
+	switch su.Type {
+	case StatePeerChanged:
+		return len(su.ChangeNodes) == 0
+	case StatePeerChangedPatch:
+		return len(su.ChangePatches) == 0
+	case StatePeerRemoved:
+		return len(su.Removed) == 0
+	}
+
+	return false
+}
+
+func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
+	return StateUpdate{
+		Type: StatePeerChangedPatch,
+		ChangePatches: []*tailcfg.PeerChange{
+			{
+				NodeID:    tailcfg.NodeID(nodeID),
+				KeyExpiry: &expiry,
+			},
+		},
+	}
+}
+
+func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
+	ctx2, _ := context.WithTimeout(
+		context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
+		3*time.Second,
+	)
+
+	return ctx2
+}

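NotifyCtx stamps each notification context with an origin and hostname and caps delivery at three seconds; these are exactly the values and the deadline the notifier's select branches log. A short self-contained sketch mirroring that helper:

package main

import (
	"context"
	"fmt"
	"time"
)

// notifyCtx mirrors types.NotifyCtx from the hunk above: it tags the context
// with an origin and hostname and caps delivery at three seconds.
func notifyCtx(ctx context.Context, origin, hostname string) context.Context {
	ctx2, _ := context.WithTimeout(
		context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
		3*time.Second,
	)

	return ctx2
}

func main() {
	ctx := notifyCtx(context.Background(), "cli-expirenode-self", "node-1")

	// These are the values the notifier logs on both the success and the
	// cancellation branch of its select.
	fmt.Println(ctx.Value("origin"), ctx.Value("hostname"))

	deadline, ok := ctx.Deadline()
	fmt.Println(ok, time.Until(deadline) <= 3*time.Second)
}
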
@ -11,7 +11,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/prometheus/common/model"
|
"github.com/prometheus/common/model"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -20,6 +19,8 @@ import (
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -46,16 +47,9 @@ type Config struct {
|
||||||
Log LogConfig
|
Log LogConfig
|
||||||
DisableUpdateCheck bool
|
DisableUpdateCheck bool
|
||||||
|
|
||||||
DERP DERPConfig
|
Database DatabaseConfig
|
||||||
|
|
||||||
DBtype string
|
DERP DERPConfig
|
||||||
DBpath string
|
|
||||||
DBhost string
|
|
||||||
DBport int
|
|
||||||
DBname string
|
|
||||||
DBuser string
|
|
||||||
DBpass string
|
|
||||||
DBssl string
|
|
||||||
|
|
||||||
TLS TLSConfig
|
TLS TLSConfig
|
||||||
|
|
||||||
|
@ -77,6 +71,31 @@ type Config struct {
|
||||||
ACL ACLConfig
|
ACL ACLConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SqliteConfig struct {
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PostgresConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Name string
|
||||||
|
User string
|
||||||
|
Pass string
|
||||||
|
Ssl string
|
||||||
|
MaxOpenConnections int
|
||||||
|
MaxIdleConnections int
|
||||||
|
ConnMaxIdleTimeSecs int
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
// Type sets the database type, either "sqlite3" or "postgres"
|
||||||
|
Type string
|
||||||
|
Debug bool
|
||||||
|
|
||||||
|
Sqlite SqliteConfig
|
||||||
|
Postgres PostgresConfig
|
||||||
|
}
|
||||||
|
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
CertPath string
|
CertPath string
|
||||||
KeyPath string
|
KeyPath string
|
||||||
|
@ -111,16 +130,19 @@ type OIDCConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DERPConfig struct {
|
type DERPConfig struct {
|
||||||
ServerEnabled bool
|
ServerEnabled bool
|
||||||
ServerRegionID int
|
AutomaticallyAddEmbeddedDerpRegion bool
|
||||||
ServerRegionCode string
|
ServerRegionID int
|
||||||
ServerRegionName string
|
ServerRegionCode string
|
||||||
ServerPrivateKeyPath string
|
ServerRegionName string
|
||||||
STUNAddr string
|
ServerPrivateKeyPath string
|
||||||
URLs []url.URL
|
STUNAddr string
|
||||||
Paths []string
|
URLs []url.URL
|
||||||
AutoUpdate bool
|
Paths []string
|
||||||
UpdateFrequency time.Duration
|
AutoUpdate bool
|
||||||
|
UpdateFrequency time.Duration
|
||||||
|
IPv4 string
|
||||||
|
IPv6 string
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogTailConfig struct {
|
type LogTailConfig struct {
|
||||||
|
@ -162,6 +184,19 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
|
|
||||||
|
viper.RegisterAlias("db_type", "database.type")
|
||||||
|
|
||||||
|
// SQLite aliases
|
||||||
|
viper.RegisterAlias("db_path", "database.sqlite.path")
|
||||||
|
|
||||||
|
// Postgres aliases
|
||||||
|
viper.RegisterAlias("db_host", "database.postgres.host")
|
||||||
|
viper.RegisterAlias("db_port", "database.postgres.port")
|
||||||
|
viper.RegisterAlias("db_name", "database.postgres.name")
|
||||||
|
viper.RegisterAlias("db_user", "database.postgres.user")
|
||||||
|
viper.RegisterAlias("db_pass", "database.postgres.pass")
|
||||||
|
viper.RegisterAlias("db_ssl", "database.postgres.ssl")
|
||||||
|
|
||||||
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
|
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
|
||||||
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
|
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
|
||||||
|
|
||||||
|
@ -173,6 +208,7 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
|
|
||||||
viper.SetDefault("derp.server.enabled", false)
|
viper.SetDefault("derp.server.enabled", false)
|
||||||
viper.SetDefault("derp.server.stun.enabled", true)
|
viper.SetDefault("derp.server.stun.enabled", true)
|
||||||
|
viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true)
|
||||||
|
|
||||||
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
|
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
|
||||||
viper.SetDefault("unix_socket_permission", "0o770")
|
viper.SetDefault("unix_socket_permission", "0o770")
|
||||||
|
@ -184,6 +220,10 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("cli.insecure", false)
|
viper.SetDefault("cli.insecure", false)
|
||||||
|
|
||||||
viper.SetDefault("db_ssl", false)
|
viper.SetDefault("db_ssl", false)
|
||||||
|
viper.SetDefault("database.postgres.ssl", false)
|
||||||
|
viper.SetDefault("database.postgres.max_open_conns", 10)
|
||||||
|
viper.SetDefault("database.postgres.max_idle_conns", 10)
|
||||||
|
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
||||||
|
|
||||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||||
viper.SetDefault("oidc.strip_email_domain", true)
|
viper.SetDefault("oidc.strip_email_domain", true)
|
||||||
|
@ -262,7 +302,7 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if errorText != "" {
|
if errorText != "" {
|
||||||
//nolint
|
// nolint
|
||||||
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
|
@ -294,8 +334,14 @@ func GetDERPConfig() DERPConfig {
|
||||||
serverRegionCode := viper.GetString("derp.server.region_code")
|
serverRegionCode := viper.GetString("derp.server.region_code")
|
||||||
serverRegionName := viper.GetString("derp.server.region_name")
|
serverRegionName := viper.GetString("derp.server.region_name")
|
||||||
stunAddr := viper.GetString("derp.server.stun_listen_addr")
|
stunAddr := viper.GetString("derp.server.stun_listen_addr")
|
||||||
privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path"))
|
privateKeyPath := util.AbsolutePathFromConfigPath(
|
||||||
|
viper.GetString("derp.server.private_key_path"),
|
||||||
|
)
|
||||||
|
ipv4 := viper.GetString("derp.server.ipv4")
|
||||||
|
ipv6 := viper.GetString("derp.server.ipv6")
|
||||||
|
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
|
||||||
|
"derp.server.automatically_add_embedded_derp_region",
|
||||||
|
)
|
||||||
if serverEnabled && stunAddr == "" {
|
if serverEnabled && stunAddr == "" {
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
|
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
|
||||||
|
@ -318,20 +364,28 @@ func GetDERPConfig() DERPConfig {
|
||||||
|
|
||||||
paths := viper.GetStringSlice("derp.paths")
|
paths := viper.GetStringSlice("derp.paths")
|
||||||
|
|
||||||
|
if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 {
|
||||||
|
log.Fatal().
|
||||||
|
Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths")
|
||||||
|
}
|
||||||
|
|
||||||
autoUpdate := viper.GetBool("derp.auto_update_enabled")
|
autoUpdate := viper.GetBool("derp.auto_update_enabled")
|
||||||
updateFrequency := viper.GetDuration("derp.update_frequency")
|
updateFrequency := viper.GetDuration("derp.update_frequency")
|
||||||
|
|
||||||
return DERPConfig{
|
return DERPConfig{
|
||||||
ServerEnabled: serverEnabled,
|
ServerEnabled: serverEnabled,
|
||||||
ServerRegionID: serverRegionID,
|
ServerRegionID: serverRegionID,
|
||||||
ServerRegionCode: serverRegionCode,
|
ServerRegionCode: serverRegionCode,
|
||||||
ServerRegionName: serverRegionName,
|
ServerRegionName: serverRegionName,
|
||||||
ServerPrivateKeyPath: privateKeyPath,
|
ServerPrivateKeyPath: privateKeyPath,
|
||||||
STUNAddr: stunAddr,
|
STUNAddr: stunAddr,
|
||||||
URLs: urls,
|
URLs: urls,
|
||||||
Paths: paths,
|
Paths: paths,
|
||||||
AutoUpdate: autoUpdate,
|
AutoUpdate: autoUpdate,
|
||||||
UpdateFrequency: updateFrequency,
|
UpdateFrequency: updateFrequency,
|
||||||
|
IPv4: ipv4,
|
||||||
|
IPv6: ipv6,
|
||||||
|
AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,6 +433,45 @@ func GetLogConfig() LogConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetDatabaseConfig() DatabaseConfig {
|
||||||
|
debug := viper.GetBool("database.debug")
|
||||||
|
|
||||||
|
type_ := viper.GetString("database.type")
|
||||||
|
|
||||||
|
switch type_ {
|
||||||
|
case DatabaseSqlite, DatabasePostgres:
|
||||||
|
break
|
||||||
|
case "sqlite":
|
||||||
|
type_ = "sqlite3"
|
||||||
|
default:
|
||||||
|
log.Fatal().
|
||||||
|
Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
|
||||||
|
}
|
||||||
|
|
||||||
|
return DatabaseConfig{
|
||||||
|
Type: type_,
|
||||||
|
Debug: debug,
|
||||||
|
Sqlite: SqliteConfig{
|
||||||
|
Path: util.AbsolutePathFromConfigPath(
|
||||||
|
viper.GetString("database.sqlite.path"),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
Postgres: PostgresConfig{
|
||||||
|
Host: viper.GetString("database.postgres.host"),
|
||||||
|
Port: viper.GetInt("database.postgres.port"),
|
||||||
|
Name: viper.GetString("database.postgres.name"),
|
||||||
|
User: viper.GetString("database.postgres.user"),
|
||||||
|
Pass: viper.GetString("database.postgres.pass"),
|
||||||
|
Ssl: viper.GetString("database.postgres.ssl"),
|
||||||
|
MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"),
|
||||||
|
MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"),
|
||||||
|
ConnMaxIdleTimeSecs: viper.GetInt(
|
||||||
|
"database.postgres.conn_max_idle_time_secs",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
||||||
if viper.IsSet("dns_config") {
|
if viper.IsSet("dns_config") {
|
||||||
dnsConfig := &tailcfg.DNSConfig{}
|
dnsConfig := &tailcfg.DNSConfig{}
|
||||||
|
@@ -580,7 +673,7 @@ func GetHeadscaleConfig() (*Config, error) {
 		if err != nil {
 			return nil, err
 		}
-		oidcClientSecret = string(secretBytes)
+		oidcClientSecret = strings.TrimSpace(string(secretBytes))
 	}
 
 	return &Config{
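The added strings.TrimSpace matters because a client secret loaded from a file almost always carries a trailing newline, which would otherwise be sent to the identity provider as part of the secret. A standalone sketch of the difference (the file path here is hypothetical):

package main

import (
	"fmt"
	"os"
	"strings"
)

func main() {
	path := "/tmp/oidc_client_secret" // hypothetical path for the sketch
	_ = os.WriteFile(path, []byte("super-secret\n"), 0o600)

	secretBytes, err := os.ReadFile(path)
	if err != nil {
		panic(err)
	}

	fmt.Printf("raw:     %q\n", string(secretBytes))                    // "super-secret\n"
	fmt.Printf("trimmed: %q\n", strings.TrimSpace(string(secretBytes))) // "super-secret"
}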
@@ -607,14 +700,7 @@ func GetHeadscaleConfig() (*Config, error) {
 			"node_update_check_interval",
 		),
 
-		DBtype: viper.GetString("db_type"),
-		DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
-		DBhost: viper.GetString("db_host"),
-		DBport: viper.GetInt("db_port"),
-		DBname: viper.GetString("db_name"),
-		DBuser: viper.GetString("db_user"),
-		DBpass: viper.GetString("db_pass"),
-		DBssl:  viper.GetString("db_ssl"),
+		Database: GetDatabaseConfig(),
 
 		TLS: GetTLSConfig(),
 
@@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
 // inform peers about smaller changes to the node.
 // When a field is added to this function, remember to also add it to:
 // - node.ApplyPeerChange
-// - logTracePeerChange in poll.go
+// - logTracePeerChange in poll.go.
 func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
 	ret := tailcfg.PeerChange{
 		NodeID: tailcfg.NodeID(node.ID),
@@ -6,6 +6,7 @@ import (
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
+	"github.com/juanfont/headscale/hscontrol/util"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 )
@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyPeerChange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nodeBefore Node
|
||||||
|
change *tailcfg.PeerChange
|
||||||
|
want Node
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hostinfo-and-netinfo-not-exists",
|
||||||
|
nodeBefore: Node{},
|
||||||
|
change: &tailcfg.PeerChange{
|
||||||
|
DERPRegion: 1,
|
||||||
|
},
|
||||||
|
want: Node{
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
NetInfo: &tailcfg.NetInfo{
|
||||||
|
PreferredDERP: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostinfo-netinfo-not-exists",
|
||||||
|
nodeBefore: Node{
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
Hostname: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
change: &tailcfg.PeerChange{
|
||||||
|
DERPRegion: 3,
|
||||||
|
},
|
||||||
|
want: Node{
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
Hostname: "test",
|
||||||
|
NetInfo: &tailcfg.NetInfo{
|
||||||
|
PreferredDERP: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostinfo-netinfo-exists-derp-set",
|
||||||
|
nodeBefore: Node{
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
Hostname: "test",
|
||||||
|
NetInfo: &tailcfg.NetInfo{
|
||||||
|
PreferredDERP: 999,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
change: &tailcfg.PeerChange{
|
||||||
|
DERPRegion: 2,
|
||||||
|
},
|
||||||
|
want: Node{
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
Hostname: "test",
|
||||||
|
NetInfo: &tailcfg.NetInfo{
|
||||||
|
PreferredDERP: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoints-not-set",
|
||||||
|
nodeBefore: Node{},
|
||||||
|
change: &tailcfg.PeerChange{
|
||||||
|
Endpoints: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: Node{
|
||||||
|
Endpoints: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoints-set",
|
||||||
|
nodeBefore: Node{
|
||||||
|
Endpoints: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort("6.6.6.6:66"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
change: &tailcfg.PeerChange{
|
||||||
|
Endpoints: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: Node{
|
||||||
|
Endpoints: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort("8.8.8.8:88"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.nodeBefore.ApplyPeerChange(tt.change)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" {
|
||||||
|
t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@@ -2,7 +2,6 @@ package types
 
 import (
 	"strconv"
-	"time"
 
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 	"github.com/juanfont/headscale/hscontrol/util"
@@ -22,12 +21,13 @@ type User struct {
 
 func (n *User) TailscaleUser() *tailcfg.User {
 	user := tailcfg.User{
 		ID:          tailcfg.UserID(n.ID),
 		LoginName:   n.Name,
 		DisplayName: n.Name,
+		// TODO(kradalby): See if we can fill in Gravatar here
 		ProfilePicURL: "",
 		Logins:        []tailcfg.LoginID{},
-		Created:       time.Time{},
+		Created:       n.CreatedAt,
 	}
 
 	return &user
@@ -35,9 +35,10 @@ func (n *User) TailscaleUser() *tailcfg.User {
 
 func (n *User) TailscaleLogin() *tailcfg.Login {
 	login := tailcfg.Login{
 		ID:          tailcfg.LoginID(n.ID),
 		LoginName:   n.Name,
 		DisplayName: n.Name,
+		// TODO(kradalby): See if we can fill in Gravatar here
 		ProfilePicURL: "",
 	}
 
@@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool {
 	return x.Compare(y) == 0
 })
 
+var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool {
+	return x == y
+})
+
 var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
 	return x.String() == y.String()
 })
@@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
 })
 
 var Comparers []cmp.Option = []cmp.Option{
-	IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer,
+	IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer,
 }
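The new AddrPortComparer is what lets cmp.Diff compare the netip.AddrPort endpoints used by the node tests; go-cmp cannot descend into netip's unexported fields on its own. A small usage sketch, separate from the diff:

package main

import (
	"fmt"
	"net/netip"

	"github.com/google/go-cmp/cmp"
)

func main() {
	addrPortComparer := cmp.Comparer(func(x, y netip.AddrPort) bool {
		return x == y
	})

	want := []netip.AddrPort{netip.MustParseAddrPort("8.8.8.8:88")}
	got := []netip.AddrPort{netip.MustParseAddrPort("6.6.6.6:66")}

	// With the comparer, cmp.Diff reports the differing endpoint instead of
	// refusing to compare netip.AddrPort's unexported fields.
	fmt.Println(cmp.Diff(want, got, addrPortComparer))
}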
@@ -265,6 +265,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -325,6 +327,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) {
 	assert.Contains(t, listAll[4].GetGivenName(), "node-5")
 
 	for idx := 0; idx < 3; idx++ {
-		_, err := headscale.Execute(
+		res, err := headscale.Execute(
 			[]string{
 				"headscale",
 				"nodes",
@@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) {
 			},
 		)
 		assert.Nil(t, err)
+
+		assert.Contains(t, res, "Node renamed")
 	}
 
 	var listAllAfterRename []v1.Node
@@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) {
 	defer scenario.Shutdown()
 
 	spec := map[string]int{
-		"user1": len(MustTestVersions),
+		"user1": 10,
+		// "user1": len(MustTestVersions),
 	}
 
-	headscaleConfig := map[string]string{}
-	headscaleConfig["HEADSCALE_DERP_URLS"] = ""
-	headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true"
-	headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999"
-	headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale"
-	headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP"
-	headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478"
-	headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key"
-	// Envknob for enabling DERP debug logs
-	headscaleConfig["DERP_DEBUG_LOGS"] = "true"
-	headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true"
+	headscaleConfig := map[string]string{
+		"HEADSCALE_DERP_URLS":                    "",
+		"HEADSCALE_DERP_SERVER_ENABLED":          "true",
+		"HEADSCALE_DERP_SERVER_REGION_ID":        "999",
+		"HEADSCALE_DERP_SERVER_REGION_CODE":      "headscale",
+		"HEADSCALE_DERP_SERVER_REGION_NAME":      "Headscale Embedded DERP",
+		"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
+		"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
+
+		// Envknob for enabling DERP debug logs
+		"DERP_DEBUG_LOGS":        "true",
+		"DERP_PROBER_DEBUG_LOGS": "true",
+	}
 
 	err = scenario.CreateHeadscaleEnv(
 		spec,
@@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allHostnames, err := scenario.ListTailscaleClientsFQDNs()
 	assertNoErrListFQDN(t, err)
 
@@ -26,12 +26,34 @@ func TestPingAllByIP(t *testing.T) {
 	assertNoErr(t, err)
 	defer scenario.Shutdown()
 
+	// TODO(kradalby): it does not look like the user thing works, only second
+	// get created? maybe only when many?
 	spec := map[string]int{
 		"user1": len(MustTestVersions),
 		"user2": len(MustTestVersions),
 	}
 
-	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip"))
+	headscaleConfig := map[string]string{
+		"HEADSCALE_DERP_URLS":                    "",
+		"HEADSCALE_DERP_SERVER_ENABLED":          "true",
+		"HEADSCALE_DERP_SERVER_REGION_ID":        "999",
+		"HEADSCALE_DERP_SERVER_REGION_CODE":      "headscale",
+		"HEADSCALE_DERP_SERVER_REGION_NAME":      "Headscale Embedded DERP",
+		"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
+		"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
+
+		// Envknob for enabling DERP debug logs
+		"DERP_DEBUG_LOGS":        "true",
+		"DERP_PROBER_DEBUG_LOGS": "true",
+	}
+
+	err = scenario.CreateHeadscaleEnv(spec,
+		[]tsic.Option{},
+		hsic.WithTestName("pingallbyip"),
+		hsic.WithConfigEnv(headscaleConfig),
+		hsic.WithTLS(),
+		hsic.WithHostnameAsServerURL(),
+	)
 	assertNoErrHeadscaleEnv(t, err)
 
 	allClients, err := scenario.ListTailscaleClients()
@@ -43,6 +65,46 @@ func TestPingAllByIP(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
+	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
+		return x.String()
+	})
+
+	success := pingAllHelper(t, allClients, allAddrs)
+	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
+}
+
+func TestPingAllByIPPublicDERP(t *testing.T) {
+	IntegrationSkip(t)
+	t.Parallel()
+
+	scenario, err := NewScenario()
+	assertNoErr(t, err)
+	defer scenario.Shutdown()
+
+	spec := map[string]int{
+		"user1": len(MustTestVersions),
+		"user2": len(MustTestVersions),
+	}
+
+	err = scenario.CreateHeadscaleEnv(spec,
+		[]tsic.Option{},
+		hsic.WithTestName("pingallbyippubderp"),
+	)
+	assertNoErrHeadscaleEnv(t, err)
+
+	allClients, err := scenario.ListTailscaleClients()
+	assertNoErrListClients(t, err)
+
+	allIps, err := scenario.ListTailscaleClientsIPs()
+	assertNoErrListClientIPs(t, err)
+
+	err = scenario.WaitForTailscaleSync()
+	assertNoErrSync(t, err)
+
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -73,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	clientIPs := make(map[TailscaleClient][]netip.Addr)
 	for _, client := range allClients {
 		ips, err := client.IPs()
@@ -112,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allClients, err = scenario.ListTailscaleClients()
 	assertNoErrListClients(t, err)
 
|
@ -263,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
assertClientsState(t, allClients)
|
||||||
|
|
||||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||||
assertNoErrListFQDN(t, err)
|
assertNoErrListFQDN(t, err)
|
||||||
|
|
||||||
|
@@ -320,9 +388,13 @@ func TestTaildrop(t *testing.T) {
 			if err != nil {
 				t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
 			}
 
 		}
-		curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
+		curlCommand := []string{
+			"curl",
+			"--unix-socket",
+			"/var/run/tailscale/tailscaled.sock",
+			"http://local-tailscaled.sock/localapi/v0/file-targets",
+		}
 		err = retry(10, 1*time.Second, func() error {
 			result, _, err := client.Execute(curlCommand)
 			if err != nil {
|
@ -339,13 +411,23 @@ func TestTaildrop(t *testing.T) {
|
||||||
for _, ft := range fts {
|
for _, ft := range fts {
|
||||||
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
|
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
|
return fmt.Errorf(
|
||||||
|
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
|
||||||
|
client.Hostname(),
|
||||||
|
len(fts),
|
||||||
|
len(allClients)-1,
|
||||||
|
ftStr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
|
t.Errorf(
|
||||||
|
"failed to query localapi for filetarget on %s, err: %s",
|
||||||
|
client.Hostname(),
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -457,6 +539,8 @@ func TestResolveMagicDNS(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	// Poor mans cache
 	_, err = scenario.ListTailscaleClientsFQDNs()
 	assertNoErrListFQDN(t, err)
@@ -525,6 +609,8 @@ func TestExpireNode(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -546,7 +632,7 @@ func TestExpireNode(t *testing.T) {
 	// TODO(kradalby): This is Headscale specific and would not play nicely
 	// with other implementations of the ControlServer interface
 	result, err := headscale.Execute([]string{
-		"headscale", "nodes", "expire", "--identifier", "0", "--output", "json",
+		"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
 	})
 	assertNoErr(t, err)
 
|
@ -577,16 +663,38 @@ func TestExpireNode(t *testing.T) {
|
||||||
assertNotNil(t, peerStatus.Expired)
|
assertNotNil(t, peerStatus.Expired)
|
||||||
assert.NotNil(t, peerStatus.KeyExpiry)
|
assert.NotNil(t, peerStatus.KeyExpiry)
|
||||||
|
|
||||||
t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
|
t.Logf(
|
||||||
|
"node %q should have a key expire before %s, was %s",
|
||||||
|
peerStatus.HostName,
|
||||||
|
now.String(),
|
||||||
|
peerStatus.KeyExpiry,
|
||||||
|
)
|
||||||
if peerStatus.KeyExpiry != nil {
|
if peerStatus.KeyExpiry != nil {
|
||||||
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
|
assert.Truef(
|
||||||
|
t,
|
||||||
|
peerStatus.KeyExpiry.Before(now),
|
||||||
|
"node %q should have a key expire before %s, was %s",
|
||||||
|
peerStatus.HostName,
|
||||||
|
now.String(),
|
||||||
|
peerStatus.KeyExpiry,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
|
assert.Truef(
|
||||||
|
t,
|
||||||
|
peerStatus.Expired,
|
||||||
|
"node %q should be expired, expired is %v",
|
||||||
|
peerStatus.HostName,
|
||||||
|
peerStatus.Expired,
|
||||||
|
)
|
||||||
|
|
||||||
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
|
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
|
||||||
if !strings.Contains(stderr, "node key has expired") {
|
if !strings.Contains(stderr, "node key has expired") {
|
||||||
t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname())
|
t.Errorf(
|
||||||
|
"expected to be unable to ping expired host %q from %q",
|
||||||
|
node.GetName(),
|
||||||
|
client.Hostname(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
|
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
|
||||||
|
@@ -598,7 +706,7 @@ func TestExpireNode(t *testing.T) {
 
 			// NeedsLogin means that the node has understood that it is no longer
 			// valid.
-			assert.Equal(t, "NeedsLogin", status.BackendState)
+			assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
 		}
 	}
 }
@@ -627,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)
 
+	assertClientsState(t, allClients)
+
 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
 		return x.String()
 	})
@@ -691,7 +801,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
 			assert.Truef(
 				t,
 				lastSeen.After(lastSeenThreshold),
-				"lastSeen (%v) was not %s after the threshold (%v)",
+				"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
+				node.GetName(),
 				lastSeen,
 				keepAliveInterval,
 				lastSeenThreshold,
@@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string {
 	return map[string]string{
 		"HEADSCALE_LOG_LEVEL":       "trace",
 		"HEADSCALE_ACL_POLICY_PATH": "",
-		"HEADSCALE_DB_TYPE":         "sqlite3",
-		"HEADSCALE_DB_PATH":         "/tmp/integration_test_db.sqlite3",
+		"HEADSCALE_DATABASE_TYPE":        "sqlite",
+		"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
 		"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
 		"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL":        "10s",
 		"HEADSCALE_IP_PREFIXES":                       "fd7a:115c:a1e0::/48 100.64.0.0/10",
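The renamed variables follow the common viper convention of upper-casing the config path and replacing dots with underscores, so HEADSCALE_DATABASE_SQLITE_PATH maps to database.sqlite.path. A hedged sketch of that mechanism; the prefix and replacer wiring below is an assumption for illustration, not code taken from headscale.

package main

import (
	"fmt"
	"os"
	"strings"

	"github.com/spf13/viper"
)

func main() {
	os.Setenv("HEADSCALE_DATABASE_TYPE", "sqlite")

	viper.SetEnvPrefix("headscale")
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	// The env var above surfaces under the nested config key.
	fmt.Println(viper.GetString("database.type")) // sqlite
}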
@@ -9,10 +9,15 @@ import (
 	"testing"
 	"time"
 
+	"github.com/google/go-cmp/cmp"
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
+	"github.com/juanfont/headscale/hscontrol/policy"
+	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/juanfont/headscale/integration/hsic"
 	"github.com/juanfont/headscale/integration/tsic"
 	"github.com/stretchr/testify/assert"
+	"tailscale.com/types/ipproto"
+	"tailscale.com/wgengine/filter"
 )
 
 // This test is both testing the routes command and the propagation of
@@ -83,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) {
 	assert.Len(t, routes, 3)
 
 	for _, route := range routes {
-		assert.Equal(t, route.GetAdvertised(), true)
-		assert.Equal(t, route.GetEnabled(), false)
-		assert.Equal(t, route.GetIsPrimary(), false)
+		assert.Equal(t, true, route.GetAdvertised())
+		assert.Equal(t, false, route.GetEnabled())
+		assert.Equal(t, false, route.GetIsPrimary())
 	}
 
 	// Verify that no routes has been sent to the client,
@@ -130,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) {
 	assert.Len(t, enablingRoutes, 3)
 
 	for _, route := range enablingRoutes {
-		assert.Equal(t, route.GetAdvertised(), true)
-		assert.Equal(t, route.GetEnabled(), true)
-		assert.Equal(t, route.GetIsPrimary(), true)
+		assert.Equal(t, true, route.GetAdvertised())
+		assert.Equal(t, true, route.GetEnabled())
+		assert.Equal(t, true, route.GetIsPrimary())
 	}
 
 	time.Sleep(5 * time.Second)
@@ -186,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) {
 	})
 	assertNoErr(t, err)
 
+	time.Sleep(5 * time.Second)
+
 	var disablingRoutes []*v1.Route
 	err = executeAndUnmarshal(
 		headscale,
@@ -204,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) {
 		assert.Equal(t, true, route.GetAdvertised())
 
 		if route.GetId() == routeToBeDisabled.GetId() {
-			assert.Equal(t, route.GetEnabled(), false)
-			assert.Equal(t, route.GetIsPrimary(), false)
+			assert.Equal(t, false, route.GetEnabled())
+			assert.Equal(t, false, route.GetIsPrimary())
 		} else {
-			assert.Equal(t, route.GetEnabled(), true)
-			assert.Equal(t, route.GetIsPrimary(), true)
+			assert.Equal(t, true, route.GetEnabled())
+			assert.Equal(t, true, route.GetIsPrimary())
 		}
 	}
 
-	time.Sleep(5 * time.Second)
-
 	// Verify that the clients can see the new routes
 	for _, client := range allClients {
 		status, err := client.Status()
|
@ -289,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// advertise HA route on node 1 and 2
|
// advertise HA route on node 1 and 2
|
||||||
// ID 1 will be primary
|
// ID 1 will be primary
|
||||||
// ID 2 will be secondary
|
// ID 2 will be secondary
|
||||||
for _, client := range allClients {
|
for _, client := range allClients[:2] {
|
||||||
status, err := client.Status()
|
status, err := client.Status()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
@@ -301,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 			}
 			_, _, err = client.Execute(command)
 			assertNoErrf(t, "failed to advertise route: %s", err)
+		} else {
+			t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
 		}
 	}
 
|
@ -323,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
|
|
||||||
|
t.Logf("initial routes %#v", routes)
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.Equal(t, true, route.GetAdvertised())
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.Equal(t, false, route.GetEnabled())
|
||||||
|
@@ -639,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assertNoErr(t, err)
 	assert.Len(t, routesAfterDisabling1, 2)
 
+	t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
+
 	// Node 1 is not primary
 	assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
 	assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
@ -778,3 +789,413 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
expectedRoutes := "172.0.0.0/24"
|
||||||
|
|
||||||
|
user := "enable-disable-routing"
|
||||||
|
|
||||||
|
scenario, err := NewScenario()
|
||||||
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
spec := map[string]int{
|
||||||
|
user: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
|
||||||
|
&policy.ACLPolicy{
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"*:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TagOwners: map[string][]string{
|
||||||
|
"tag:approve": {user},
|
||||||
|
},
|
||||||
|
AutoApprovers: policy.AutoApprovers{
|
||||||
|
Routes: map[string][]string{
|
||||||
|
expectedRoutes: {"tag:approve"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
))
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErrGetHeadscale(t, err)
|
||||||
|
|
||||||
|
subRouter1 := allClients[0]
|
||||||
|
|
||||||
|
// Initially advertise route
|
||||||
|
command := []string{
|
||||||
|
"tailscale",
|
||||||
|
"set",
|
||||||
|
"--advertise-routes=" + expectedRoutes,
|
||||||
|
}
|
||||||
|
_, _, err = subRouter1.Execute(command)
|
||||||
|
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
var routes []*v1.Route
|
||||||
|
err = executeAndUnmarshal(
|
||||||
|
headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&routes,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, routes, 1)
|
||||||
|
|
||||||
|
// All routes should be auto approved and enabled
|
||||||
|
assert.Equal(t, true, routes[0].GetAdvertised())
|
||||||
|
assert.Equal(t, true, routes[0].GetEnabled())
|
||||||
|
assert.Equal(t, true, routes[0].GetIsPrimary())
|
||||||
|
|
||||||
|
// Stop advertising route
|
||||||
|
command = []string{
|
||||||
|
"tailscale",
|
||||||
|
"set",
|
||||||
|
"--advertise-routes=",
|
||||||
|
}
|
||||||
|
_, _, err = subRouter1.Execute(command)
|
||||||
|
assertNoErrf(t, "failed to remove advertised route: %s", err)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
var notAdvertisedRoutes []*v1.Route
|
||||||
|
err = executeAndUnmarshal(
|
||||||
|
headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
¬AdvertisedRoutes,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, notAdvertisedRoutes, 1)
|
||||||
|
|
||||||
|
// Route is no longer advertised
|
||||||
|
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised())
|
||||||
|
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled())
|
||||||
|
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary())
|
||||||
|
|
||||||
|
// Advertise route again
|
||||||
|
command = []string{
|
||||||
|
"tailscale",
|
||||||
|
"set",
|
||||||
|
"--advertise-routes=" + expectedRoutes,
|
||||||
|
}
|
||||||
|
_, _, err = subRouter1.Execute(command)
|
||||||
|
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
var reAdvertisedRoutes []*v1.Route
|
||||||
|
err = executeAndUnmarshal(
|
||||||
|
headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&reAdvertisedRoutes,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, reAdvertisedRoutes, 1)
|
||||||
|
|
||||||
|
// All routes should be auto approved and enabled
|
||||||
|
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised())
|
||||||
|
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled())
|
||||||
|
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSubnetRouteACL verifies that Subnet routes are distributed
|
||||||
|
// as expected when ACLs are activated.
|
||||||
|
// It implements the issue from
|
||||||
|
// https://github.com/juanfont/headscale/issues/1604
|
||||||
|
func TestSubnetRouteACL(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
user := "subnet-route-acl"
|
||||||
|
|
||||||
|
scenario, err := NewScenario()
|
||||||
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
spec := map[string]int{
|
||||||
|
user: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
|
||||||
|
&policy.ACLPolicy{
|
||||||
|
Groups: policy.Groups{
|
||||||
|
"group:admins": {user},
|
||||||
|
},
|
||||||
|
ACLs: []policy.ACL{
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"group:admins"},
|
||||||
|
Destinations: []string{"group:admins:*"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"group:admins"},
|
||||||
|
Destinations: []string{"10.33.0.0/16:*"},
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// Action: "accept",
|
||||||
|
// Sources: []string{"group:admins"},
|
||||||
|
// Destinations: []string{"0.0.0.0/0:*"},
|
||||||
|
// },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
))
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErrGetHeadscale(t, err)
|
||||||
|
|
||||||
|
expectedRoutes := map[string]string{
|
||||||
|
"1": "10.33.0.0/16",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort nodes by ID
|
||||||
|
sort.SliceStable(allClients, func(i, j int) bool {
|
||||||
|
statusI, err := allClients[i].Status()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
statusJ, err := allClients[j].Status()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return statusI.Self.ID < statusJ.Self.ID
|
||||||
|
})
|
||||||
|
|
||||||
|
subRouter1 := allClients[0]
|
||||||
|
|
||||||
|
client := allClients[1]
|
||||||
|
|
||||||
|
// advertise HA route on node 1 and 2
|
||||||
|
// ID 1 will be primary
|
||||||
|
// ID 2 will be secondary
|
||||||
|
for _, client := range allClients {
|
||||||
|
status, err := client.Status()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
|
||||||
|
command := []string{
|
||||||
|
"tailscale",
|
||||||
|
"set",
|
||||||
|
"--advertise-routes=" + route,
|
||||||
|
}
|
||||||
|
_, _, err = client.Execute(command)
|
||||||
|
assertNoErrf(t, "failed to advertise route: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
var routes []*v1.Route
|
||||||
|
err = executeAndUnmarshal(
|
||||||
|
headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&routes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, routes, 1)
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
assert.Equal(t, true, route.GetAdvertised())
|
||||||
|
assert.Equal(t, false, route.GetEnabled())
|
||||||
|
assert.Equal(t, false, route.GetIsPrimary())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that no routes has been sent to the client,
|
||||||
|
// they are not yet enabled.
|
||||||
|
for _, client := range allClients {
|
||||||
|
status, err := client.Status()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
for _, peerKey := range status.Peers() {
|
||||||
|
peerStatus := status.Peer[peerKey]
|
||||||
|
|
||||||
|
assert.Nil(t, peerStatus.PrimaryRoutes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable all routes
|
||||||
|
for _, route := range routes {
|
||||||
|
_, err = headscale.Execute(
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"enable",
|
||||||
|
"--route",
|
||||||
|
strconv.Itoa(int(route.GetId())),
|
||||||
|
})
|
||||||
|
assertNoErr(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
var enablingRoutes []*v1.Route
|
||||||
|
err = executeAndUnmarshal(
|
||||||
|
headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"routes",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&enablingRoutes,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, enablingRoutes, 1)
|
||||||
|
|
||||||
|
// Node 1 has active route
|
||||||
|
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
||||||
|
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
||||||
|
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
|
||||||
|
|
||||||
|
// Verify that the client has routes from the primary machine
|
||||||
|
srs1, _ := subRouter1.Status()
|
||||||
|
|
||||||
|
clientStatus, err := client.Status()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
|
|
||||||
|
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice())
|
||||||
|
assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1)
|
||||||
|
assert.Contains(
|
||||||
|
t,
|
||||||
|
srs1PeerStatus.PrimaryRoutes.AsSlice(),
|
||||||
|
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
|
||||||
|
)
|
||||||
|
|
||||||
|
clientNm, err := client.Netmap()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
wantClientFilter := []filter.Match{
|
||||||
|
{
|
||||||
|
IPProto: []ipproto.Proto{
|
||||||
|
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||||
|
},
|
||||||
|
Srcs: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
|
},
|
||||||
|
Dsts: []filter.NetPortRange{
|
||||||
|
{
|
||||||
|
Net: netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
|
Ports: filter.PortRange{0, 0xffff},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
|
Ports: filter.PortRange{0, 0xffff},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Caps: []filter.CapMatch{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" {
|
||||||
|
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
subnetNm, err := subRouter1.Netmap()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
wantSubnetFilter := []filter.Match{
|
||||||
|
{
|
||||||
|
IPProto: []ipproto.Proto{
|
||||||
|
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||||
|
},
|
||||||
|
Srcs: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
|
},
|
||||||
|
Dsts: []filter.NetPortRange{
|
||||||
|
{
|
||||||
|
Net: netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
Ports: filter.PortRange{0, 0xffff},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
|
Ports: filter.PortRange{0, 0xffff},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Caps: []filter.CapMatch{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IPProto: []ipproto.Proto{
|
||||||
|
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
|
||||||
|
},
|
||||||
|
Srcs: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||||
|
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||||
|
},
|
||||||
|
Dsts: []filter.NetPortRange{
|
||||||
|
{
|
||||||
|
Net: netip.MustParsePrefix("10.33.0.0/16"),
|
||||||
|
Ports: filter.PortRange{0, 0xffff},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Caps: []filter.CapMatch{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
|
||||||
|
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@@ -56,8 +56,8 @@ var (
 		"1.44": true,  // CapVer: 63
 		"1.42": true,  // CapVer: 61
 		"1.40": true,  // CapVer: 61
-		"1.38": true,  // CapVer: 58
-		"1.36": true,  // Oldest supported version, CapVer: 56
+		"1.38": true,  // Oldest supported version, CapVer: 58
+		"1.36": false, // CapVer: 56
 		"1.34": false, // CapVer: 51
 		"1.32": false, // CapVer: 46
 		"1.30": false,
@@ -7,6 +7,8 @@ import (
 	"github.com/juanfont/headscale/integration/dockertestutil"
 	"github.com/juanfont/headscale/integration/tsic"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/net/netcheck"
+	"tailscale.com/types/netmap"
 )
 
 // nolint
@@ -26,6 +28,8 @@ type TailscaleClient interface {
 	IPs() ([]netip.Addr, error)
 	FQDN() (string, error)
 	Status() (*ipnstate.Status, error)
+	Netmap() (*netmap.NetworkMap, error)
+	Netcheck() (*netcheck.Report, error)
 	WaitForNeedsLogin() error
 	WaitForRunning() error
 	WaitForPeers(expected int) error
@@ -17,6 +17,8 @@ import (
 	"github.com/ory/dockertest/v3"
 	"github.com/ory/dockertest/v3/docker"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/net/netcheck"
+	"tailscale.com/types/netmap"
 )
 
 const (
@@ -519,6 +521,53 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
 	return &status, err
 }
 
+// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
+// Only works with Tailscale 1.56.1 and newer.
+func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
+	command := []string{
+		"tailscale",
+		"debug",
+		"netmap",
+	}
+
+	result, stderr, err := t.Execute(command)
+	if err != nil {
+		fmt.Printf("stderr: %s\n", stderr)
+		return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
+	}
+
+	var nm netmap.NetworkMap
+	err = json.Unmarshal([]byte(result), &nm)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
+	}
+
+	return &nm, err
+}
+
+// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
+func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
+	command := []string{
+		"tailscale",
+		"netcheck",
+		"--format=json",
+	}
+
+	result, stderr, err := t.Execute(command)
+	if err != nil {
+		fmt.Printf("stderr: %s\n", stderr)
+		return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err)
+	}
+
+	var nm netcheck.Report
+	err = json.Unmarshal([]byte(result), &nm)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
+	}
+
+	return &nm, err
+}
+
 // FQDN returns the FQDN as a string of the Tailscale instance.
 func (t *TailscaleInContainer) FQDN() (string, error) {
 	if t.fqdn != "" {
@@ -623,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
 			len(peers),
 		)
 	} else {
+		// Verify that the peers of a given node is Online
+		// has a hostname and a DERP relay.
 		for _, peerKey := range peers {
 			peer := status.Peer[peerKey]
 
 			if !peer.Online {
 				return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
 			}
+
+			if peer.HostName == "" {
+				return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)
+			}
+
+			if peer.Relay == "" {
+				return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)
+			}
 		}
 	}
 
@@ -7,6 +7,8 @@ import (
 	"time"
 
 	"github.com/juanfont/headscale/integration/tsic"
+	"github.com/stretchr/testify/assert"
+	"tailscale.com/util/cmpver"
 )
 
 const (
@@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
 		for _, addr := range addrs {
 			err := client.Ping(addr, opts...)
 			if err != nil {
-				t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
+				t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
 			} else {
 				success++
 			}
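Switching from t.Fatalf to t.Errorf keeps the ping loop going, so a single run reports every unreachable pair instead of aborting at the first failure. A toy illustration of the difference (the names below are made up for the sketch):

package example

import "testing"

// fakePing stands in for client.Ping in the helper above.
func fakePing(addr string) error { return nil }

func TestReportsAllFailures(t *testing.T) {
	addrs := []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}

	for _, addr := range addrs {
		if err := fakePing(addr); err != nil {
			// Errorf marks the test as failed but continues the loop;
			// Fatalf here would stop after the first unreachable address.
			t.Errorf("failed to ping %s: %s", addr, err)
		}
	}
}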
@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string)
|
||||||
return success
|
return success
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assertClientsState validates the status and netmap of a list of
|
||||||
|
// clients for the general case of all to all connectivity.
|
||||||
|
func assertClientsState(t *testing.T, clients []TailscaleClient) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for _, client := range clients {
|
||||||
|
assertValidStatus(t, client)
|
||||||
|
assertValidNetmap(t, client)
|
||||||
|
assertValidNetcheck(t, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertValidNetmap asserts that the netmap of a client has all
|
||||||
|
// the minimum required fields set to a known working config for
|
||||||
|
// the general case. Fields are checked on self, then all peers.
|
||||||
|
// This test is not suitable for ACL/partial connection tests.
|
||||||
|
// This test can only be run on clients from 1.56.1. It will
|
||||||
|
// automatically pass all clients below that and is safe to call
|
||||||
|
// for all versions.
|
||||||
|
func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if cmpver.Compare("1.56.1", client.Version()) <= 0 ||
|
||||||
|
!strings.Contains(client.Hostname(), "unstable") ||
|
||||||
|
!strings.Contains(client.Hostname(), "head") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
netmap, err := client.Netmap()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||||
|
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||||
|
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||||
|
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||||
|
|
||||||
|
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
|
||||||
|
|
||||||
|
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||||
|
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||||
|
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
|
||||||
|
|
||||||
|
for _, peer := range netmap.Peers {
|
||||||
|
assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP())
|
||||||
|
|
||||||
|
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||||
|
if hi := peer.Hostinfo(); hi.Valid() {
|
||||||
|
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||||
|
|
||||||
|
// Netinfo is not always set
|
||||||
|
assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
|
||||||
|
if ni := hi.NetInfo(); ni.Valid() {
|
||||||
|
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
|
||||||
|
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
|
||||||
|
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
|
||||||
|
|
||||||
|
assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
|
||||||
|
|
||||||
|
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
|
||||||
|
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
|
||||||
|
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertValidStatus asserts that the status of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
	t.Helper()
	status, err := client.Status()
	if err != nil {
		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
	}

	assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
	assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
	assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())

	assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())

	// This seems to not appear until version 1.56
	if status.Self.AllowedIPs != nil {
		assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
	}

	assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())

	assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())

	assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())

	// This isn't really relevant for Self as it won't be in its own socket/wireguard.
	// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
	// assert.Truef(t, status.Self.InEngine, "%q is not in the wireguard engine", client.Hostname())

	for _, peer := range status.Peer {
		assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
		assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
		assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())

		assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())

		// This seems to not appear until version 1.56
		if peer.AllowedIPs != nil {
			assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
		}

		// Addrs does not seem to appear in the status from peers.
		// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())

		assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())

		assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
		assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())

		// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
		// there might be some interesting stuff to test here in the future.
		// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
	}
}

func assertValidNetcheck(t *testing.T, client TailscaleClient) {
	t.Helper()
	report, err := client.Netcheck()
	if err != nil {
		t.Fatalf("getting netcheck report for %q: %s", client.Hostname(), err)
	}

	assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}

func isSelfClient(client TailscaleClient, addr string) bool {
	if addr == client.Hostname() {
		return true

@ -152,7 +296,7 @@ func isCI() bool {
}

func dockertestMaxWait() time.Duration {
-	wait := 60 * time.Second //nolint
+	wait := 120 * time.Second //nolint

	if isCI() {
		wait = 300 * time.Second //nolint
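For reference, a minimal sketch of how the validity helpers added above might be driven from a test, assuming the harness hands the test a slice of TailscaleClient instances; the assertClientsState name and its signature are illustrative, not taken from this commit:

// assertClientsState is a hypothetical wrapper that runs the validity
// helpers above against every client handed to it.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
	t.Helper()

	for _, client := range clients {
		assertValidStatus(t, client)
		assertValidNetcheck(t, client)
		assertValidNetmap(t, client)
	}
}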