Merge upstream/main into fork

commit cc430af3f9
Tyler Beckman, 2024-02-11 21:48:13 -07:00
Signed by: Ty (GPG key ID: 2813440C772555A4)
52 changed files with 3272 additions and 1415 deletions


@ -50,3 +50,16 @@ instead of filing a bug report.
## To Reproduce
<!-- Steps to reproduce the behavior. -->
## Logs and attachments
<!-- Please attach files with:
- Client netmap dump (see below)
- ACL configuration
- Headscale configuration
Dump the netmap of tailscale clients:
`tailscale debug netmap > DESCRIPTIVE_NAME.json`
Please provide information describing the netmap: which client, which headscale version, etc.
-->


@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestEnableDisableAutoApprovedRoute
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestEnableDisableAutoApprovedRoute:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestEnableDisableAutoApprovedRoute
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestEnableDisableAutoApprovedRoute$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"


@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestPingAllByIPPublicDERP
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestPingAllByIPPublicDERP:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestPingAllByIPPublicDERP
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestPingAllByIPPublicDERP$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"


@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestSubnetRouteACL
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestSubnetRouteACL:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestSubnetRouteACL
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestSubnetRouteACL$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"


@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473)
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
- The latest supported client is 1.36
- The latest supported client is 1.38
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
- If no DERP is configured, the server will fail to start; this can happen if it cannot load the DERPMap from a file or URL.
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611)
@ -34,17 +34,21 @@ after improving the test harness as part of adopting [#1460](https://github.com/
### Changes
Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644)
- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484)
- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487)
- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492)
- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460)
- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259)
- Added the possibility to manually create a DERP map entry that can be customized, instead of creating it automatically. [#1565](https://github.com/juanfont/headscale/pull/1565)
- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700)
- The old structure is now considered deprecated and will be removed in the future.
- Adds additional configuration for PostgreSQL for setting max open connections, max idle connections, and idle connection lifetime (see the sketch after this list).
- Allow use of the username OIDC claim [#1287](https://github.com/juanfont/headscale/pull/1287)
- Add `oidc.groups_claim`, `oidc.email_claim`, and `oidc.username_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)
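
As a rough sketch of what the new database structure maps to in code, the following Go types are reconstructed from how `cfg.Database` is used in the `hscontrol/db` changes later in this commit. The field names and types here are inferred assumptions, not the authoritative definitions (which live in `hscontrol/types`):

```go
package types

// DatabaseConfig sketches the new nested database configuration. It is
// inferred from usages of cfg.Type, cfg.Debug, cfg.Sqlite.Path and
// cfg.Postgres.* elsewhere in this commit.
type DatabaseConfig struct {
	Type  string // DatabaseSqlite ("sqlite") or DatabasePostgres ("postgres")
	Debug bool

	Sqlite   SqliteConfig
	Postgres PostgresConfig
}

type SqliteConfig struct {
	Path string // e.g. /var/lib/headscale/db.sqlite
}

type PostgresConfig struct {
	Host string // or a Unix socket path; leave Port as 0 for sockets
	Port int
	Name string
	User string
	Pass string
	Ssl  string // "true"/"false" or a libpq sslmode value

	MaxOpenConnections  int
	MaxIdleConnections  int
	ConnMaxIdleTimeSecs int
}
```
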
## 0.22.3 (2023-05-12)


@ -6,25 +6,11 @@ import (
"github.com/efekarakus/termcolor"
"github.com/juanfont/headscale/cmd/headscale/cli"
"github.com/pkg/profile"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var colors bool
switch l := termcolor.SupportLevel(os.Stderr); l {
case termcolor.Level16M:


@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("database.type"), check.Equals, "sqlite")
c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080")
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite")
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")


@ -94,6 +94,16 @@ derp:
#
private_key_path: /var/lib/headscale/derp_server_private.key
# This flag can be used so that the DERP map entry for the embedded DERP server is not written automatically;
# it enables the creation of your own DERP map entry using a locally available file with the parameter DERP.paths.
# If you enable the DERP server and set this to false, you must add the DERP server to the DERP map using DERP.paths.
automatically_add_embedded_derp_region: true
# For better connection stability (especially when using an exit node and DNS is not working),
# it is possible to optionally add the public IPv4 and IPv6 addresses to the DERP map using:
ipv4: 1.2.3.4
ipv6: 2001:db8::1
# List of externally available DERP maps encoded in JSON
urls:
- https://controlplane.tailscale.com/derpmap/default
@ -128,24 +138,28 @@ ephemeral_node_inactivity_timeout: 30m
# In case of doubts, do not touch the default 10s.
node_update_check_interval: 10s
# SQLite config
db_type: sqlite3
database:
type: sqlite
# For production:
db_path: /var/lib/headscale/db.sqlite
# SQLite config
sqlite:
path: /var/lib/headscale/db.sqlite
# # Postgres config
# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
# db_type: postgres
# db_host: localhost
# db_port: 5432
# db_name: headscale
# db_user: foo
# db_pass: bar
# # Postgres config
# postgres:
# # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
# host: localhost
# port: 5432
# name: headscale
# user: foo
# pass: bar
# max_open_conns: 10
# max_idle_conns: 10
# conn_max_idle_time_secs: 3600
# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
# db_ssl: false
# # If another 'sslmode' is required instead of 'require(true)' and 'disable(false)', set the 'sslmode' you need
# # in the 'ssl' field. Refer to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
# ssl: false
### TLS configuration
#


@ -18,6 +18,7 @@ You can set these using the Windows Registry Editor:
Or via the following PowerShell commands (right-click the PowerShell icon and select "Run as administrator"):
```
New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
```


@ -12,7 +12,6 @@ import (
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
@ -33,6 +32,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/pkg/profile"
"github.com/prometheus/client_golang/prometheus/promhttp"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -48,6 +48,7 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
@ -61,7 +62,7 @@ var (
"unknown value for Lets Encrypt challenge type",
)
errEmptyInitialDERPMap = errors.New(
"initial DERPMap is empty, Headscale requries at least one entry",
"initial DERPMap is empty, Headscale requires at least one entry",
)
)
@ -116,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}
var dbString string
switch cfg.DBtype {
case db.Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.DBhost,
cfg.DBname,
cfg.DBuser,
)
if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
}
if cfg.DBport != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.DBport)
}
if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
}
case db.Sqlite:
dbString = cfg.DBpath
default:
return nil, errUnsupportedDatabase
}
registrationCache := cache.New(
registerCacheExpiration,
registerCacheCleanup,
@ -154,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
@ -163,9 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
database, err := db.NewHeadscaleDatabase(
cfg.DBtype,
dbString,
app.dbDebug,
cfg.Database,
app.nodeNotifier,
cfg.IPPrefixes,
cfg.BaseDomain)
@ -234,8 +200,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
var update types.StateUpdate
var changed bool
for range ticker.C {
h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout)
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
@ -246,9 +227,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
ticker := time.NewTicker(interval)
lastCheck := time.Unix(0, 0)
var update types.StateUpdate
var changed bool
for range ticker.C {
lastCheck = h.db.ExpireExpiredNodes(lastCheck)
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
if changed && update.Valid() {
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
@ -268,7 +264,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
case <-ticker.C:
log.Info().Msg("Fetching DERPMap updates")
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.ServerEnabled {
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
h.DERPMap.Regions[region.RegionID] = &region
}
@ -278,7 +274,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
DERPMap: h.DERPMap,
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyAll(stateUpdate)
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
}
}
@ -485,6 +482,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
err := os.MkdirAll(profilePath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
defer profile.Start(profile.ProfilePath(profilePath)).Stop()
} else {
defer profile.Start().Stop()
}
}
var err error
// Fetch an initial DERP Map before we start serving
@ -501,7 +511,9 @@ func (h *Headscale) Serve() error {
return err
}
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
h.DERPMap.Regions[region.RegionID] = &region
}
go h.DERPServer.ServeSTUN()
}
@ -708,14 +720,16 @@ func (h *Headscale) Serve() error {
var tailsqlContext context.Context
if tailsqlEnabled {
if h.cfg.DBtype != db.Sqlite {
log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite)
if h.cfg.Database.Type != types.DatabaseSqlite {
log.Fatal().
Str("type", h.cfg.Database.Type).
Msgf("tailsql only supports %q", types.DatabaseSqlite)
}
if tailsqlTSKey == "" {
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
}
tailsqlContext = context.Background()
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath)
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
}
// Handle common process-killing signals so we can gracefully shut down:
@ -751,7 +765,8 @@ func (h *Headscale) Serve() error {
Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change")
h.nodeNotifier.NotifyAll(types.StateUpdate{
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}


@ -1,6 +1,7 @@
package hscontrol
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -8,6 +9,7 @@ import (
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -199,6 +201,19 @@ func (h *Headscale) handleRegister(
return
}
// When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed
if node.NodeKey.String() != registerRequest.NodeKey.String() &&
registerRequest.OldNodeKey.IsZero() && !node.IsExpired() {
h.handleNodeKeyRefresh(
writer,
registerRequest,
*node,
machineKey,
)
return
}
if registerRequest.Followup != "" {
select {
case <-req.Context().Done():
@ -230,8 +245,6 @@ func (h *Headscale) handleRegister(
// handleAuthKey contains the logic to manage auth key client registration
// When using Noise, the machineKey is Zero.
//
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
@ -298,6 +311,9 @@ func (h *Headscale) handleAuthKey(
nodeKey := registerRequest.NodeKey
var update types.StateUpdate
var mkey key.MachinePublic
// retrieve node information if it exists
// The error is not important, because if it does not
// exist, then this is a new node and we will move
@ -311,7 +327,7 @@ func (h *Headscale) handleAuthKey(
node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID)
err := h.db.NodeSetExpiry(node, registerRequest.Expiry)
err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
@ -322,10 +338,13 @@ func (h *Headscale) handleAuthKey(
return
}
mkey = node.MachineKey
update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
aclTags := pak.Proto().GetAclTags()
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.db.SetTags(node, aclTags)
err = h.db.SetTags(node.ID, aclTags)
if err != nil {
log.Error().
@ -357,6 +376,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
@ -380,9 +400,18 @@ func (h *Headscale) handleAuthKey(
return
}
mkey = node.MachineKey
update = types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from auth.handleAuthKey",
}
}
err = h.db.UsePreAuthKey(pak)
err = h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak)
})
if err != nil {
log.Error().
Caller().
@ -424,6 +453,13 @@ func (h *Headscale) handleAuthKey(
Caller().
Err(err).
Msg("Failed to write response")
return
}
// TODO(kradalby): check if notifying after register makes sense.
if update.Valid() {
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
}
log.Info().
@ -489,7 +525,7 @@ func (h *Headscale) handleNodeLogOut(
Msg("Client requested logout")
now := time.Now()
err := h.db.NodeSetExpiry(&node, now)
err := h.db.NodeSetExpiry(node.ID, now)
if err != nil {
log.Error().
Caller().
@ -500,17 +536,10 @@ func (h *Headscale) handleNodeLogOut(
return
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &now,
},
},
}
stateUpdate := types.StateUpdateExpire(node.ID, now)
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
resp.AuthURL = ""
@ -541,7 +570,7 @@ func (h *Headscale) handleNodeLogOut(
}
if node.IsEphemeral() {
err = h.db.DeleteNode(&node)
err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
if err != nil {
log.Error().
Err(err).
@ -549,6 +578,15 @@ func (h *Headscale) handleNodeLogOut(
Msg("Cannot delete ephemeral node from the database")
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return
}
@ -620,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh(
Str("node", node.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey)
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
})
if err != nil {
log.Error().
Caller().


@ -13,16 +13,23 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gorm.io/gorm"
)
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) {
return getAvailableIPs(rx, hsdb.ipPrefixes)
})
}
func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) {
var ips types.NodeAddresses
var err error
for _, ipPrefix := range hsdb.ipPrefixes {
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
ip, err = hsdb.getAvailableIP(ipPrefix)
ip, err = getAvailableIP(rx, ipPrefix)
if err != nil {
return ips, err
}
@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) {
return ips, err
}
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := hsdb.getUsedIPs()
func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := getUsedIPs(rx)
if err != nil {
return nil, err
}
@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro
}
}
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
for _, slice := range addressesSlices {


@ -7,10 +7,16 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ips, err := db.getAvailableIPs()
tx := db.DB.Begin()
defer tx.Rollback()
ips, err := getAvailableIPs(tx, []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
})
c.Assert(err, check.IsNil)
@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.db.Save(&node)
usedIps, err := db.getUsedIPs()
db.Write(func(tx *gorm.DB) error {
return tx.Save(&node).Error
})
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
}
func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := db.CreateUser("test-ip-multi")
user, err := db.CreateUser("test-ip")
c.Assert(err, check.IsNil)
ipPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
}
for index := 1; index <= 350; index++ {
db.ipAllocationMutex.Lock()
tx := db.DB.Begin()
ips, err := db.getAvailableIPs()
ips, err := getAvailableIPs(tx, ipPrefixes)
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = getNode(tx, "test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
db.db.Save(&node)
db.ipAllocationMutex.Unlock()
tx.Save(&node)
c.Assert(tx.Commit().Error, check.IsNil)
}
usedIps, err := db.getUsedIPs()
usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) {
return getUsedIPs(rx)
})
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
node := types.Node{
@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
ips2, err := db.getAvailableIPs()
c.Assert(err, check.IsNil)


@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time,
) (string, *types.APIKey, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
if err != nil {
return "", nil, err
@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
Expiration: expiration,
}
if err := hsdb.db.Save(&key).Error; err != nil {
if err := hsdb.DB.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
}
@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
// ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
keys := []types.APIKey{}
if err := hsdb.db.Find(&keys).Error; err != nil {
if err := hsdb.DB.Find(&keys).Error; err != nil {
return nil, err
}
@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
// GetAPIKey returns an ApiKey for a given key.
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{}
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error
}
@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
// GetAPIKeyByID returns an ApiKey for a given id.
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
key := types.APIKey{}
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error
}
@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
// DestroyAPIKey destroys an ApiKey. Returns an error if the ApiKey
// does not exist.
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
return result.Error
}
@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
// ExpireAPIKey marks an ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err
}
@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
}
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
prefix, hash, found := strings.Cut(keyStr, ".")
if !found {
return false, ErrAPIKeyFailedToParse


@ -6,24 +6,20 @@ import (
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/glebarez/sqlite"
"github.com/go-gormigrate/gormigrate/v2"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
Postgres = "postgres"
Sqlite = "sqlite3"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
var errDatabaseNotSupported = errors.New("database type not supported")
@ -36,12 +32,7 @@ type KV struct {
}
type HSDatabase struct {
db *gorm.DB
notifier *notifier.Notifier
mu sync.RWMutex
ipAllocationMutex sync.Mutex
DB *gorm.DB
ipPrefixes []netip.Prefix
baseDomain string
@ -50,18 +41,20 @@ type HSDatabase struct {
// TODO(kradalby): assemble this struct from options or something typed
// rather than arguments.
func NewHeadscaleDatabase(
dbType, connectionAddr string,
debug bool,
cfg types.DatabaseConfig,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(dbType, connectionAddr, debug)
dbConn, err := openDB(cfg)
if err != nil {
return nil, err
}
migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{
migrations := gormigrate.New(
dbConn,
gormigrate.DefaultOptions,
[]*gormigrate.Migration{
// New migrations should be added as transactions at the end of this list.
// The initial commit here is quite messy, completely out of order, and
// has no versioning; it is the tech debt of not having versioned migrations
@ -70,7 +63,7 @@ func NewHeadscaleDatabase(
{
ID: "202312101416",
Migrate: func(tx *gorm.DB) error {
if dbType == Postgres {
if cfg.Type == types.DatabasePostgres {
tx.Exec(`create extension if not exists "uuid-ossp";`)
}
@ -78,23 +71,28 @@ func NewHeadscaleDatabase(
// the big rename from Machine to Node
_ = tx.Migrator().RenameTable("machines", "nodes")
_ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id")
_ = tx.Migrator().
RenameColumn(&types.Route{}, "machine_id", "node_id")
err = tx.AutoMigrate(types.User{})
if err != nil {
return err
}
_ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id")
_ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
_ = tx.Migrator().
RenameColumn(&types.Node{}, "namespace_id", "user_id")
_ = tx.Migrator().
RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
_ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
_ = tx.Migrator().
RenameColumn(&types.Node{}, "ip_address", "ip_addresses")
_ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname")
// GivenName is used as the primary source of DNS names, make sure
// the field is populated and normalized if it was not when the
// node was registered.
_ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name")
_ = tx.Migrator().
RenameColumn(&types.Node{}, "nickname", "given_name")
// If the Node table has a column for registered,
// find all occurrences of "false" and drop them. Then
@ -147,7 +145,9 @@ func NewHeadscaleDatabase(
DiscoKey string
}
var results []result
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error
err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").
Find(&results).
Error
if err != nil {
return err
}
@ -180,7 +180,8 @@ func NewHeadscaleDatabase(
}
if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...")
log.Info().
Msgf("Database has legacy enabled_routes column in node, migrating...")
type NodeAux struct {
ID uint64
@ -188,7 +189,10 @@ func NewHeadscaleDatabase(
}
nodesAux := []NodeAux{}
err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error
err := tx.Table("nodes").
Select("id, enabled_routes").
Scan(&nodesAux).
Error
if err != nil {
log.Fatal().Err(err).Msg("Error accessing db")
}
@ -234,7 +238,9 @@ func NewHeadscaleDatabase(
err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes")
if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column")
log.Error().
Err(err).
Msg("Error dropping enabled_routes column")
}
}
@ -310,15 +316,15 @@ func NewHeadscaleDatabase(
return nil
},
},
})
},
)
if err = migrations.Migrate(); err != nil {
log.Fatal().Err(err).Msgf("Migration failed: %v", err)
}
db := HSDatabase{
db: dbConn,
notifier: notifier,
DB: dbConn,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
@ -327,20 +333,19 @@ func NewHeadscaleDatabase(
return &db, err
}
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
// TODO(kradalby): Integrate this with zerolog
var dbLogger logger.Interface
if debug {
if cfg.Debug {
dbLogger = logger.Default
} else {
dbLogger = logger.Default.LogMode(logger.Silent)
}
switch dbType {
case Sqlite:
switch cfg.Type {
case types.DatabaseSqlite:
db, err := gorm.Open(
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
@ -359,16 +364,51 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
return db, err
case Postgres:
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
case types.DatabasePostgres:
dbString := fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.Postgres.Host,
cfg.Postgres.Name,
cfg.Postgres.User,
)
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
}
if cfg.Postgres.Port != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
}
if cfg.Postgres.Pass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
}
db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
})
if err != nil {
return nil, err
}
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections)
sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections)
sqlDB.SetConnMaxIdleTime(
time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second,
)
return db, nil
}
return nil, fmt.Errorf(
"database of type %s is not supported: %w",
dbType,
cfg.Type,
errDatabaseNotSupported,
)
}
@ -376,7 +416,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
sqlDB, err := hsdb.db.DB()
sqlDB, err := hsdb.DB.DB()
if err != nil {
return err
}
@ -385,10 +425,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
}
func (hsdb *HSDatabase) Close() error {
db, err := hsdb.db.DB()
db, err := hsdb.DB.DB()
if err != nil {
return err
}
return db.Close()
}
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
rx := hsdb.DB.Begin()
defer rx.Rollback()
return fn(rx)
}
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
rx := db.Begin()
defer rx.Rollback()
ret, err := fn(rx)
if err != nil {
var no T
return no, err
}
return ret, nil
}
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
tx := hsdb.DB.Begin()
defer tx.Rollback()
if err := fn(tx); err != nil {
return err
}
return tx.Commit().Error
}
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
tx := db.Begin()
defer tx.Rollback()
ret, err := fn(tx)
if err != nil {
var no T
return no, err
}
return ret, tx.Commit().Error
}
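
A minimal usage sketch of the generic `Read`/`Write` helpers above, composed with the transaction-scoped functions from this package; `registerAndListPeers` and `newNode` are hypothetical names introduced for illustration, while the helper and function signatures are taken from this commit:

```go
package db

import (
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/rs/zerolog/log"
	"gorm.io/gorm"
)

// registerAndListPeers is a hypothetical example composing the generic
// Write and Read wrappers with RegisterNode and ListPeers.
func registerAndListPeers(hsdb *HSDatabase, newNode types.Node) (types.Nodes, error) {
	// Write begins a transaction, rolls back on error and commits on success.
	registered, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
		return RegisterNode(tx, newNode, hsdb.ipPrefixes)
	})
	if err != nil {
		log.Error().Err(err).Msg("failed to register node")

		return nil, err
	}

	// Read begins a transaction that is always rolled back, so read-only
	// callers cannot accidentally persist changes.
	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
		return ListPeers(rx, registered)
	})
}
```
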


@ -34,22 +34,21 @@ var (
)
)
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listPeers(node)
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, node)
})
}
func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Msg("Finding direct peers")
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodes()
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
}
func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
nodes := []types.Node{}
if err := hsdb.db.
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByGivenName(givenName)
}
func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) {
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err
return nodes, nil
}
// GetNode finds a Node by name and user and returns the Node struct.
func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name)
})
}
nodes, err := hsdb.ListNodesByUser(user)
// getNode finds a Node by name and user and returns the Node struct.
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
nodes, err := ListNodesByUser(tx, user)
if err != nil {
return nil, err
}
@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) {
return nil, ErrNodeNotFound
}
// GetNodeByGivenName finds a Node by given name and user and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByGivenName(
user string,
givenName string,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if err := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).First(&node).Error; err != nil {
return nil, err
}
return nil, ErrNodeNotFound
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id)
})
}
// GetNodeByID finds a Node by ID and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
return &mach, nil
}
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByMachineKey(
machineKey key.MachinePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeByMachineKey(machineKey)
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByMachineKey(rx, machineKey)
})
}
func (hsdb *HSDatabase) getNodeByMachineKey(
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
func GetNodeByMachineKey(
tx *gorm.DB,
machineKey key.MachinePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey(
return &mach, nil
}
// GetNodeByNodeKey finds a Node by its current NodeKey.
func (hsdb *HSDatabase) GetNodeByNodeKey(
func (hsdb *HSDatabase) GetNodeByAnyKey(
machineKey key.MachinePublic,
nodeKey key.NodePublic,
oldNodeKey key.NodePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if result := hsdb.db.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&node, "node_key = ?",
nodeKey.String()); result.Error != nil {
return nil, result.Error
}
return &node, nil
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
})
}
// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
func (hsdb *HSDatabase) GetNodeByAnyKey(
// TODO(kradalby): see if we can remove this.
func GetNodeByAnyKey(
tx *gorm.DB,
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*types.Node, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
node := types.Node{}
if result := hsdb.db.
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
return &node, nil
}
func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
if result := hsdb.db.Find(node).First(&node); result.Error != nil {
return result.Error
}
return nil
func (hsdb *HSDatabase) SetTags(
nodeID uint64,
tags []string,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
return SetTags(tx, nodeID, tags)
})
}
// SetTags takes a Node struct pointer and updates the forced tags.
func (hsdb *HSDatabase) SetTags(
node *types.Node,
func SetTags(
tx *gorm.DB,
nodeID uint64,
tags []string,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if len(tags) == 0 {
return nil
}
newTags := []string{}
newTags := types.StringList{}
for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) {
newTags = append(newTags, tag)
}
}
if err := hsdb.db.Model(node).Updates(types.Node{
ForcedTags: newTags,
}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.SetTags",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
}
// RenameNode takes a Node struct and a new GivenName for the node
// and renames it.
func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func RenameNode(tx *gorm.DB,
nodeID uint64, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
)
@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
log.Error().
Caller().
Str("func", "RenameNode").
Str("node", node.Hostname).
Uint64("nodeID", nodeID).
Str("newName", newName).
Err(err).
Msg("failed to rename node")
return err
}
node.GivenName = newName
if err := hsdb.db.Model(node).Updates(types.Node{
GivenName: newName,
}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.RenameNode",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
}
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry)
})
}
// NodeSetExpiry takes a Node struct and a new expiry time.
func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.nodeSetExpiry(node, expiry)
func NodeSetExpiry(tx *gorm.DB,
nodeID uint64, expiry time.Time,
) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}
func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error {
if err := hsdb.db.Model(node).Updates(types.Node{
Expiry: &expiry,
}).Error; err != nil {
return fmt.Errorf(
"failed to refresh node (update expiration) in the database: %w",
err,
)
}
node.Expiry = &expiry
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &expiry,
},
},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DeleteNode(tx, node, isConnected)
})
}
// DeleteNode deletes a Node from the database.
func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.deleteNode(node)
}
func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
err := hsdb.deleteNodeRoutes(node)
// The caller is responsible for notifying all nodes of the change.
func DeleteNode(tx *gorm.DB,
node *types.Node,
isConnected map[key.MachinePublic]bool,
) error {
err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
if err != nil {
return err
}
// Unscoped causes the node to be fully removed from the database.
if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil {
if err := tx.Unscoped().Delete(&node).Error; err != nil {
return err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
// UpdateLastSeen sets a node's last seen field to indicate that we
// have recently communicated with this node.
// This is mostly used to indicate if a node is online, so it is not
// critical that it is fully correct; to avoid holding up the hot
// path it does not take any locks and is not concurrency safe.
// But that should be OK.
func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {
return hsdb.db.Model(node).Updates(types.Node{
LastSeen: node.LastSeen,
}).Error
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
func RegisterNodeFromAuthCallback(
tx *gorm.DB,
cache *cache.Cache,
mkey key.MachinePublic,
userName string,
nodeExpiry *time.Time,
registrationMethod string,
ipPrefixes []netip.Prefix,
) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
log.Debug().
Str("machine_key", mkey.ShortString()).
Str("userName", userName).
@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName)
user, err := GetUser(tx, userName)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
}
registrationNode.UserID = user.ID
registrationNode.User = *user
registrationNode.RegisterMethod = registrationMethod
if nodeExpiry != nil {
registrationNode.Expiry = nodeExpiry
}
node, err := hsdb.registerNode(
node, err := RegisterNode(
tx,
registrationNode,
ipPrefixes,
)
if err == nil {
@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
return nil, ErrNodeNotFoundRegistrationCache
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.registerNode(node)
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, hsdb.ipPrefixes)
})
}
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
// so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache
if len(node.IPAddresses) > 0 {
if err := hsdb.db.Save(&node).Error; err != nil {
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
}
@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
return &node, nil
}
hsdb.ipAllocationMutex.Lock()
defer hsdb.ipAllocationMutex.Unlock()
ips, err := hsdb.getAvailableIPs()
ips, err := getAvailableIPs(tx, ipPrefixes)
if err != nil {
log.Error().
Caller().
@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
node.IPAddresses = ips
if err := hsdb.db.Save(&node).Error; err != nil {
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed to register (save) node in the database: %w", err)
}
@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
}
// NodeSetNodeKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if err := hsdb.db.Model(node).Updates(types.Node{
func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error {
return tx.Model(node).Updates(types.Node{
NodeKey: nodeKey,
}).Error; err != nil {
return err
}
return nil
}).Error
}
// NodeSetMachineKey sets the node key of a node and saves it to the database.
func (hsdb *HSDatabase) NodeSetMachineKey(
node *types.Node,
machineKey key.MachinePublic,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetMachineKey(tx, node, machineKey)
})
}
if err := hsdb.db.Model(node).Updates(types.Node{
// NodeSetMachineKey sets the node key of a node and saves it to the database.
func NodeSetMachineKey(
tx *gorm.DB,
node *types.Node,
machineKey key.MachinePublic,
) error {
return tx.Model(node).Updates(types.Node{
MachineKey: machineKey,
}).Error; err != nil {
return err
}
return nil
}).Error
}
// NodeSave saves a node object to the database, prefer to use a specific save method rather
// than this. It is intended to be used when we are changing or.
func (hsdb *HSDatabase) NodeSave(node *types.Node) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
// TODO(kradalby): Remove this func, just use Save.
func NodeSave(tx *gorm.DB, node *types.Node) error {
return tx.Save(node).Error
}
if err := hsdb.db.Save(node).Error; err != nil {
return err
}
return nil
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetAdvertisedRoutes(rx, node)
})
}
// GetAdvertisedRoutes returns the routes that are advertised by the given node.
func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) {
func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e
return prefixes, nil
}
// GetEnabledRoutes returns the routes that are enabled for the node.
func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getEnabledRoutes(node)
return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) {
return GetEnabledRoutes(rx, node)
})
}
func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) {
// GetEnabledRoutes returns the routes that are enabled for the node.
func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) {
routes := types.Routes{}
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true).
Find(&routes).Error
@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro
return prefixes, nil
}
func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := hsdb.getEnabledRoutes(node)
enabledRoutes, err := GetEnabledRoutes(tx, node)
if err != nil {
log.Error().Err(err).Msg("Could not get enabled routes")
@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
return false
}
func (hsdb *HSDatabase) enableRoutes(
node *types.Node,
routeStrs ...string,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return enableRoutes(tx, node, routeStrs...)
})
}
// enableRoutes enables new routes on a node from a list of route strings.
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
func enableRoutes(tx *gorm.DB,
node *types.Node, routeStrs ...string,
) (*types.StateUpdate, error) {
newRoutes := make([]netip.Prefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netip.ParsePrefix(routeStr)
if err != nil {
return err
return nil, err
}
newRoutes[index] = route
}
advertisedRoutes, err := hsdb.getAdvertisedRoutes(node)
advertisedRoutes, err := GetAdvertisedRoutes(tx, node)
if err != nil {
return err
return nil, err
}
for _, newRoute := range newRoutes {
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
return fmt.Errorf(
return nil, fmt.Errorf(
"route (%s) is not available on node %s: %w",
node.Hostname,
newRoute, ErrNodeRouteIsNotAvailable,
@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Separate loop so we don't leave things in a half-updated state
for _, prefix := range newRoutes {
route := types.Route{}
err := hsdb.db.Preload("Node").
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
First(&route).Error
if err == nil {
@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
// Mark it as primary right away if this node is the only one offering this subnet
// (and it is not an exit route)
if !route.IsExitRoute() {
route.IsPrimary = hsdb.isUniquePrefix(route)
route.IsPrimary = isUniquePrefix(tx, route)
}
err = hsdb.db.Save(&route).Error
err = tx.Save(&route).Error
if err != nil {
return fmt.Errorf("failed to enable route: %w", err)
return nil, fmt.Errorf("failed to enable route: %w", err)
}
} else {
return fmt.Errorf("failed to find route: %w", err)
return nil, fmt.Errorf("failed to find route: %w", err)
}
}
// Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := hsdb.getNodeRoutes(node)
nRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return fmt.Errorf("failed to read back routes: %w", err)
return nil, fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
@ -729,17 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
Strs("routes", routeStrs).
Msg("enabling routes")
stateUpdate := types.StateUpdate{
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from db.enableRoutes",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(
stateUpdate, node.MachineKey.String())
}
return nil
Message: "created in db.enableRoutes",
}, nil
}
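// With the notifier hook removed from enableRoutes, sending the state update
// is now the caller's responsibility; a hypothetical caller in this package
// (the Valid check mirrors the old in-place notification path):

func enableRoutesAndNotify(hsdb *HSDatabase, notif *notifier.Notifier, node *types.Node, routeStrs ...string) error {
    update, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
        return enableRoutes(tx, node, routeStrs...)
    })
    if err != nil {
        return err
    }

    // Notify outside the transaction so it is not held open while
    // peers are written to.
    if update != nil && update.Valid() {
        notif.NotifyWithIgnore(*update, node.MachineKey.String())
    }

    return nil
}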
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
@ -772,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
return GenerateGivenName(rx, mkey, suppliedName)
})
}
func GenerateGivenName(
tx *gorm.DB,
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
givenName, err := generateGivenName(suppliedName, false)
if err != nil {
return "", err
}
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
nodes, err := hsdb.listNodesByGivenName(givenName)
nodes, err := listNodesByGivenName(tx, givenName)
if err != nil {
return "", err
}
@ -805,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName(
return givenName, nil
}
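// The collision branch elided from this hunk presumably retries with
// generateGivenName(suppliedName, true) to append a random suffix. A
// self-contained sketch of such a DNS-safe suffix helper (hypothetical,
// requires "crypto/rand"; the real implementation may differ):

func randomDNSSuffixSketch(length int) (string, error) {
    const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"

    buf := make([]byte, length)
    if _, err := rand.Read(buf); err != nil {
        return "", err
    }

    // Map each random byte onto the DNS-safe alphabet; the slight modulo
    // bias is acceptable for a non-cryptographic name suffix.
    for i := range buf {
        buf[i] = alphabet[int(buf[i])%len(alphabet)]
    }

    return string(buf), nil
}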
func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
users, err := hsdb.listUsers()
func ExpireEphemeralNodes(tx *gorm.DB,
inactivityThreshhold time.Duration,
) (types.StateUpdate, bool) {
users, err := ListUsers(tx)
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return
return types.StateUpdate{}, false
}
expired := make([]tailcfg.NodeID, 0)
for _, user := range users {
nodes, err := hsdb.listNodesByUser(user.Name)
nodes, err := ListNodesByUser(tx, user.Name)
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing nodes in user")
return
return types.StateUpdate{}, false
}
expired := make([]tailcfg.NodeID, 0)
for idx, node := range nodes {
if node.IsEphemeral() && node.LastSeen != nil &&
time.Now().
@ -838,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
Str("node", node.Hostname).
Msg("Ephemeral client removed from database")
err = hsdb.deleteNode(nodes[idx])
// pass an empty isConnected map, as ephemeral nodes do not serve routes
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
if err != nil {
log.Error().
Err(err).
@ -848,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration)
}
}
// TODO(kradalby): needs to be moved out of transaction
}
if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
return types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
}
}, true
}
return types.StateUpdate{}, false
}
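// A hypothetical periodic caller: run the expiry inside a write transaction
// and, per the TODO above, notify only after the transaction has finished
// (names are illustrative):

func expireEphemeralAndNotify(hsdb *HSDatabase, notif *notifier.Notifier, inactivity time.Duration) {
    var update types.StateUpdate
    var changed bool

    err := hsdb.Write(func(tx *gorm.DB) error {
        update, changed = ExpireEphemeralNodes(tx, inactivity)

        return nil
    })
    if err == nil && changed && update.Valid() {
        notif.NotifyAll(update)
    }
}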
func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {
// use the time of the start of the function to ensure we
// don't miss any nodes by returning it _after_ we have
// checked everything.
started := time.Now()
expiredNodes := make([]*types.Node, 0)
expired := make([]*tailcfg.PeerChange, 0)
nodes, err := hsdb.listNodes()
nodes, err := ListNodes(tx)
if err != nil {
log.Error().
Err(err).
Msg("Error listing nodes to find expired nodes")
return time.Unix(0, 0)
return time.Unix(0, 0), types.StateUpdate{}, false
}
for index, node := range nodes {
if node.IsExpired() &&
@ -882,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
// It will notify about all nodes that have been expired.
// It should only notify about nodes that expired since the _last check_.
node.Expiry.After(lastCheck) {
expiredNodes = append(expiredNodes, &nodes[index])
expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
now := time.Now()
// Do not use setNodeExpiry as that has a notifier hook, which
// can cause a deadlock; we are updating all changed nodes later
// and there is no point in notifying twice.
if err := hsdb.db.Model(nodes[index]).Updates(types.Node{
Expiry: &started,
if err := tx.Model(&nodes[index]).Updates(types.Node{
Expiry: &now,
}).Error; err != nil {
log.Error().
Err(err).
@ -904,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
}
}
expired := make([]*tailcfg.PeerChange, len(expiredNodes))
for idx, node := range expiredNodes {
expired[idx] = &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &started,
}
}
// Inform the peers of a node with a lightweight update.
stateUpdate := types.StateUpdate{
if len(expired) > 0 {
return started, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}, true
}
// Inform the node itself that it has expired.
for _, node := range expiredNodes {
stateSelfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if stateSelfUpdate.Valid() {
hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey)
}
}
return started
return started, types.StateUpdate{}, false
}
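// Sketch of the expected scheduler-side usage: thread lastCheck through
// successive runs and only fan out the patch update when something actually
// expired (loop shape and names are assumptions, not part of this diff):

func expireExpiredLoopSketch(hsdb *HSDatabase, notif *notifier.Notifier, every time.Duration) {
    lastCheck := time.Unix(0, 0)

    for range time.Tick(every) {
        var update types.StateUpdate
        var changed bool

        err := hsdb.Write(func(tx *gorm.DB) error {
            lastCheck, update, changed = ExpireExpiredNodes(tx, lastCheck)

            return nil
        })
        if err == nil && changed && update.Valid() {
            notif.NotifyAll(update)
        }
    }
}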


@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
}
@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
_, err = db.GetNodeByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.db.Save(&node)
db.DB.Save(&node)
err = db.DeleteNode(&node)
err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil)
_, err = db.GetNode(user.Name, "testnode3")
_, err = db.getNode(user.Name, "testnode3")
c.Assert(err, check.NotNil)
}
@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
}
node0ByID, err := db.GetNodeByID(0)
@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
}
aclPolicy := &policy.ACLPolicy{
@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) {
AuthKeyID: uint(pak.ID),
Expiry: &time.Time{},
}
db.db.Save(node)
db.DB.Save(node)
nodeFromDB, err := db.GetNode("test", "testnode")
nodeFromDB, err := db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(nodeFromDB, check.NotNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, false)
now := time.Now()
err = db.NodeSetExpiry(nodeFromDB, now)
err = db.NodeSetExpiry(nodeFromDB.ID, now)
c.Assert(err, check.IsNil)
nodeFromDB, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("user-1", "testnode")
_, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "testnode")
_, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(node)
db.DB.Save(node)
// assign simple tags
sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(node, sTags)
err = db.SetTags(node.ID, sTags)
c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode")
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
// assign duplicate tags, expect no errors but no duplicates in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(node, eTags)
err = db.SetTags(node.ID, eTags)
c.Assert(err, check.IsNil)
node, err = db.GetNode("test", "testnode")
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(
node.ForcedTags,
@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
}
db.db.Save(&node)
db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)
err = db.EnableAutoApprovedRoutes(pol, node0ByID)
// TODO(kradalby): Check state update
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(node0ByID)


@ -20,7 +20,6 @@ var (
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
)
// CreatePreAuthKey creates a new PreAuthKey for a user, and returns it.
func (hsdb *HSDatabase) CreatePreAuthKey(
userName string,
reusable bool,
@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
// TODO(kradalby): figure out this lock
// hsdb.mu.Lock()
// defer hsdb.mu.Unlock()
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
})
}
user, err := hsdb.GetUser(userName)
// CreatePreAuthKey creates a new PreAuthKey for a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil {
return nil, err
}
@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
}
now := time.Now().UTC()
kstr, err := hsdb.generateKey()
kstr, err := generateKey()
if err != nil {
return nil, err
}
@ -63,9 +72,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
Expiration: expiration,
}
err = hsdb.db.Transaction(func(db *gorm.DB) error {
if err := db.Save(&key).Error; err != nil {
return fmt.Errorf("failed to create key in the database: %w", err)
if err := tx.Save(&key).Error; err != nil {
return nil, fmt.Errorf("failed to create key in the database: %w", err)
}
if len(aclTags) > 0 {
@ -73,8 +81,8 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
for _, tag := range aclTags {
if !seenTags[tag] {
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return nil, fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
@ -84,9 +92,6 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
}
}
return nil
})
if err != nil {
return nil, err
}
@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
return &key, nil
}
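// Typical use from the CLI/gRPC layer goes through the wrapper above; a
// sketch (user name, expiry and tags are illustrative only):

func createCIKeySketch(hsdb *HSDatabase) (*types.PreAuthKey, error) {
    expiration := time.Now().Add(24 * time.Hour)

    return hsdb.CreatePreAuthKey(
        "ci-user",
        true,  // reusable
        false, // ephemeral
        &expiration,
        []string{"tag:ci"},
    )
}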
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listPreAuthKeys(userName)
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeys(rx, userName)
})
}
func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
user, err := hsdb.getUser(userName)
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
if err != nil {
return nil, err
}
keys := []types.PreAuthKey{}
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
return nil, err
}
@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er
}
// GetPreAuthKey returns a PreAuthKey for a given key.
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
pak, err := hsdb.ValidatePreAuthKey(key)
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
pak, err := ValidatePreAuthKey(tx, key)
if err != nil {
return nil, err
}
@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe
// DestroyPreAuthKey destroys a preauthkey. Returns an error if the PreAuthKey
// does not exist.
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.destroyPreAuthKey(pak)
}
func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
return hsdb.db.Transaction(func(db *gorm.DB) error {
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
return tx.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
return result.Error
}
@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error {
})
}
// ExpirePreAuthKey marks a PreAuthKey as expired.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.Write(func(tx *gorm.DB) error {
return ExpirePreAuthKey(tx, k)
})
}
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
}
// UsePreAuthKey marks a PreAuthKey as used.
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
if err := hsdb.db.Save(k).Error; err != nil {
if err := tx.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err)
}
return nil
}
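// ValidatePreAuthKey and UsePreAuthKey are presumably meant to run in the
// same write transaction during registration, so a single-use key cannot be
// consumed twice; a sketch (the Reusable field name is an assumption):

func consumePreAuthKeySketch(hsdb *HSDatabase, k string) (*types.PreAuthKey, error) {
    return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
        pak, err := ValidatePreAuthKey(tx, k)
        if err != nil {
            return nil, err
        }

        if !pak.Reusable {
            if err := UsePreAuthKey(tx, pak); err != nil {
                return nil, err
            }
        }

        return pak, nil
    })
}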
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
}
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node.
// If it returns no error and a PreAuthKey, the key can be used.
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
}
nodes := types.Nodes{}
if err := hsdb.db.
if err := tx.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
Find(&nodes).Error; err != nil {
@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
return &pak, nil
}
func (hsdb *HSDatabase) generateKey() (string, error) {
func generateKey() (string, error) {
size := 24
bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil {


@ -6,6 +6,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := db.CreateUser("test2")
c.Assert(err, check.IsNil)
now := time.Now()
now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil)
@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) {
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
_, err = db.ValidatePreAuthKey(pak.Key)
// Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil)
_, err = db.GetNode("test7", "testest")
_, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil)
db.ExpireEphemeralNodes(time.Second * 20)
db.DB.Transaction(func(tx *gorm.DB) error {
ExpireEphemeralNodes(tx, time.Second*20)
return nil
})
// The machine record should have been deleted
_, err = db.GetNode("test7", "testest")
_, err = db.getNode("test7", "testest")
c.Assert(err, check.NotNil)
}
@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
pak.Used = true
db.db.Save(&pak)
db.DB.Save(&pak)
_, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)


@ -7,23 +7,15 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"gorm.io/gorm"
"tailscale.com/types/key"
)
var ErrRouteIsNotAvailable = errors.New("route is not available")
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoutes()
}
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
func GetRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Find(&routes).Error
@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true).
@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro
return routes, nil
}
func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeAdvertisedRoutes(node)
}
func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) {
func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID).
@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
}
func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeRoutes(node)
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodeRoutes(rx, node)
})
}
func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
Where("node_id = ?", node.ID).
@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
return routes, nil
}
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getRoute(id)
}
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) {
var route types.Route
err := hsdb.db.
err := tx.
Preload("Node").
Preload("Node.User").
First(&route, id).Error
@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
return &route, nil
}
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.enableRoute(id)
}
func (hsdb *HSDatabase) enableRoute(id uint64) error {
route, err := hsdb.getRoute(id)
func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.IsExitRoute() {
return hsdb.enableRoutes(
return enableRoutes(
tx,
&route.Node,
types.ExitRouteV4.String(),
types.ExitRouteV6.String(),
)
}
return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String())
return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String())
}
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
route, err := hsdb.getRoute(id)
func DisableRoute(tx *gorm.DB,
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
var routes types.Routes
@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() {
err = hsdb.failoverRouteWithNotify(route)
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil {
return err
return nil, err
}
route.Enabled = false
route.IsPrimary = false
err = hsdb.db.Save(route).Error
err = tx.Save(route).Error
if err != nil {
return err
return nil, err
}
} else {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
for i := range routes {
if routes[i].IsExitRoute() {
routes[i].Enabled = false
routes[i].IsPrimary = false
err = hsdb.db.Save(&routes[i]).Error
err = tx.Save(&routes[i]).Error
if err != nil {
return err
return nil, err
}
}
}
}
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
// If update is nil, it means that one was not created
// by failover (as a failover was not necessary), so create
// one and return it to the caller.
if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DisableRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
return update, nil
}
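// The new isConnected parameter decouples liveness from the DB layer: the
// caller snapshots which machines it believes are online before entering the
// transaction. A sketch (how the caller tracks liveness is assumed):

func disableRouteAndNotify(hsdb *HSDatabase, notif *notifier.Notifier, id uint64, online []key.MachinePublic) error {
    isConnected := make(map[key.MachinePublic]bool, len(online))
    for _, mk := range online {
        isConnected[mk] = true
    }

    update, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
        return DisableRoute(tx, id, isConnected)
    })
    if err != nil {
        return err
    }

    if update != nil && update.Valid() {
        notif.NotifyAll(*update)
    }

    return nil
}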
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func (hsdb *HSDatabase) DeleteRoute(
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return DeleteRoute(tx, id, isConnected)
})
}
route, err := hsdb.getRoute(id)
func DeleteRoute(
tx *gorm.DB,
id uint64,
isConnected map[key.MachinePublic]bool,
) (*types.StateUpdate, error) {
route, err := GetRoute(tx, id)
if err != nil {
return err
return nil, err
}
var routes types.Routes
@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate
if !route.IsExitRoute() {
err := hsdb.failoverRouteWithNotify(route)
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
if err != nil {
return nil
return nil, nil
}
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
if err := tx.Unscoped().Delete(&route).Error; err != nil {
return nil, err
}
} else {
routes, err := hsdb.getNodeRoutes(&node)
routes, err := GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
routesToDelete := types.Routes{}
@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
}
}
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err
if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil {
return nil, err
}
}
// If update is nil, it means that one was not created
// by failover (as a failover was not necessary), so create
// one and return it to the caller.
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return err
return nil, err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
if update == nil {
update = &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DeleteRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
return update, nil
}
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
routes, err := GetNodeRoutes(tx, node)
if err != nil {
return err
}
for i := range routes {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return err
}
// TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes need to be failed over rather than all.
hsdb.failoverRouteWithNotify(&routes[i])
failoverRouteReturnUpdate(tx, isConnected, &routes[i])
}
return nil
}
// isUniquePrefix reports whether no other node already advertises and serves the same route.
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64
hsdb.db.
Model(&types.Route{}).
tx.Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.NodeID,
@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
return count == 0
}
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route
err := hsdb.db.
err := tx.
Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
First(&route).Error
@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro
return &route, nil
}
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetNodePrimaryRoutes(rx, node)
})
}
// GetNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover).
// Exit nodes are not considered for this, as they are never marked as primary.
func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
err := tx.
Preload("Node").
Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true).
Find(&routes).Error
@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
return routes, nil
}
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool wheter an update should be sent as the
// saved route impacts nodes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.saveNodeRoutes(node)
return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) {
return SaveNodeRoutes(tx, node)
})
}
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
// It returns a bool indicating whether an update should be sent, as the
// saved route impacts nodes.
func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
sendUpdate := false
currentRoutes := types.Routes{}
err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
err := tx.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil {
return sendUpdate, err
}
@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised {
currentRoutes[pos].Advertised = true
err := hsdb.db.Save(&currentRoutes[pos]).Error
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
} else if route.Advertised {
currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false
err := hsdb.db.Save(&currentRoutes[pos]).Error
err := tx.Save(&currentRoutes[pos]).Error
if err != nil {
return sendUpdate, err
}
@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
Advertised: true,
Enabled: false,
}
err := hsdb.db.Create(&route).Error
err := tx.Create(&route).Error
if err != nil {
return sendUpdate, err
}
@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
// EnsureFailoverRouteIsAvailable takes a node and checks that each of its routes
// currently has a functioning host exposing the network.
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
nodeRoutes, err := hsdb.getNodeRoutes(node)
func EnsureFailoverRouteIsAvailable(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
node *types.Node,
) (*types.StateUpdate, error) {
nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil {
return nil
return nil, nil
}
var changedNodes types.Nodes
for _, nodeRoute := range nodeRoutes {
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil {
return err
return nil, err
}
for _, route := range routes {
if route.IsPrimary {
// if we have a primary route and the node is connected,
// nothing needs to be done.
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
if isConnected[route.Node.MachineKey] {
continue
}
// if not, we need to fail over the route
err := hsdb.failoverRouteWithNotify(&route)
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
if err != nil {
return err
return nil, err
}
if update != nil {
changedNodes = append(changedNodes, update.ChangeNodes...)
}
}
}
}
return nil
}
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
if err != nil {
return nil
}
var changedKeys []key.MachinePublic
for _, route := range routes {
changed, err := hsdb.failoverRoute(&route)
if err != nil {
return err
}
changedKeys = append(changedKeys, changed...)
}
changedKeys = lo.Uniq(changedKeys)
var nodes types.Nodes
for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
if err != nil {
return err
}
nodes = append(nodes, node)
}
if nodes != nil {
stateUpdate := types.StateUpdate{
if len(changedNodes) != 0 {
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.FailoverNodeRoutesWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
ChangeNodes: changedNodes,
Message: "called from db.EnsureFailoverRouteIsAvailable",
}, nil
}
return nil
return nil, nil
}
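// Sketch of the intended call site: when a node drops its connection, check
// whether any of its primary routes must fail over, and broadcast the
// resulting update outside the transaction (wiring is assumed):

func onNodeDisconnectSketch(hsdb *HSDatabase, notif *notifier.Notifier, node *types.Node, isConnected map[key.MachinePublic]bool) error {
    update, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
        return EnsureFailoverRouteIsAvailable(tx, isConnected, node)
    })
    if err != nil {
        return err
    }

    if update != nil && update.Valid() {
        notif.NotifyAll(*update)
    }

    return nil
}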
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
changedKeys, err := hsdb.failoverRoute(r)
func failoverRouteReturnUpdate(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) (*types.StateUpdate, error) {
changedKeys, err := failoverRoute(tx, isConnected, r)
if err != nil {
return err
return nil, err
}
log.Trace().
Interface("isConnected", isConnected).
Interface("changedKeys", changedKeys).
Msg("building route failover")
if len(changedKeys) == 0 {
return nil
return nil, nil
}
var nodes types.Nodes
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("loading machines with new primary routes from db")
for _, key := range changedKeys {
node, err := hsdb.getNodeByMachineKey(key)
node, err := GetNodeByMachineKey(tx, key)
if err != nil {
return err
return nil, err
}
nodes = append(nodes, node)
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notifying peers about primary route change")
if nodes != nil {
stateUpdate := types.StateUpdate{
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.failoverRouteWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notified peers about primary route change")
return nil
Message: "called from db.failoverRouteReturnUpdate",
}, nil
}
// failoverRoute takes a route that is no longer available,
@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
//
// and tries to find a new route to take over its place.
// If the given route was not primary, it returns early.
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
func failoverRoute(
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) ([]key.MachinePublic, error) {
if r == nil {
return nil, nil
}
// This route is not a primary route, and it isnt
// This route is not a primary route, and it is not
// being served to nodes.
if !r.IsPrimary {
return nil, nil
@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return nil, nil
}
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
if err != nil {
return nil, err
}
@ -585,14 +543,18 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
continue
}
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
if !route.Enabled {
continue
}
if isConnected[route.Node.MachineKey] {
newPrimary = &routes[idx]
break
}
}
// If a new route was not found/available,
// return with an error.
// return without an error.
// We do not want to update the database as
// the one currently marked as primary is the
// best we got.
@ -606,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Remove primary from the old route
r.IsPrimary = false
err = hsdb.db.Save(&r).Error
err = tx.Save(&r).Error
if err != nil {
log.Error().Err(err).Msg("error disabling new primary route")
@ -619,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
// Set primary for the new primary
newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error
err = tx.Save(&newPrimary).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
@ -634,19 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
node *types.Node,
) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
) (*types.StateUpdate, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
})
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
aclPolicy *policy.ACLPolicy,
node *types.Node,
) (*types.StateUpdate, error) {
if len(node.IPAddresses) == 0 {
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}
routes, err := hsdb.getNodeAdvertisedRoutes(node)
routes, err := GetNodeAdvertisedRoutes(tx, node)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().
Caller().
@ -654,9 +623,11 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("node", node.Hostname).
Msg("Could not get advertised routes for node")
return err
return nil, err
}
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
approvedRoutes := types.Routes{}
for _, advertisedRoute := range routes {
@ -673,9 +644,16 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Uint64("nodeId", node.ID).
Msg("Failed to resolve autoApprovers for advertised route")
return err
return nil, err
}
log.Trace().
Str("node", node.Hostname).
Str("user", node.User.Name).
Strs("routeApprovers", routeApprovers).
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
Msg("looking up route for autoapproving")
for _, approvedAlias := range routeApprovers {
if approvedAlias == node.User.Name {
approvedRoutes = append(approvedRoutes, advertisedRoute)
@ -687,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy")
return err
return nil, err
}
// approvedIPs should contain all of the node's IPs if it matches the rule, so checking the first one is enough
@ -698,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
}
}
update := &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{},
Message: "created in db.EnableAutoApprovedRoutes",
}
for _, approvedRoute := range approvedRoutes {
err := hsdb.enableRoute(uint64(approvedRoute.ID))
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID).
Msg("Failed to enable approved route")
return err
}
return nil, err
}
return nil
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
}
return update, nil
}
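// Expected usage during node registration: auto-approve matching routes in
// one transaction, then forward the aggregated update to peers (how the
// policy is obtained is assumed):

func autoApproveAndNotify(hsdb *HSDatabase, notif *notifier.Notifier, pol *policy.ACLPolicy, node *types.Node) error {
    update, err := hsdb.EnableAutoApprovedRoutes(pol, node)
    if err != nil {
        return err
    }

    if update != nil && update.Valid() {
        notif.NotifyWithIgnore(*update, node.MachineKey.String())
    }

    return nil
}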


@ -24,7 +24,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_get_route_node")
_, err = db.getNode("test", "test_get_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24")
@ -42,7 +42,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -52,10 +52,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1)
err = db.enableRoutes(&node, "192.168.0.0/24")
// TODO(kradalby): check state update
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
@ -66,7 +67,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -91,7 +92,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
db.DB.Save(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
@ -106,10 +107,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = db.enableRoutes(&node, "192.168.0.0/24")
_, err = db.enableRoutes(&node, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(&node)
@ -117,14 +118,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
err = db.enableRoutes(&node, "10.0.0.0/24")
_, err = db.enableRoutes(&node, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = db.enableRoutes(&node, "150.0.10.0/25")
_, err = db.enableRoutes(&node, "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node)
@ -139,7 +140,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
@ -163,16 +164,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String())
_, err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, route2.String())
_, err = db.enableRoutes(&node1, route2.String())
c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{
@ -186,13 +187,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2,
}
db.db.Save(&node2)
db.DB.Save(&node2)
sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String())
_, err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -219,7 +220,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
_, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
@ -246,22 +247,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
db.DB.Save(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String())
_, err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String())
_, err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil)
routes, err := db.GetNodeRoutes(&node1)
c.Assert(err, check.IsNil)
err = db.DeleteRoute(uint64(routes[0].ID))
// TODO(kradalby): check stateupdate
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -269,17 +271,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(len(enabledRoutes1), check.Equals, 1)
}
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
func TestFailoverRoute(t *testing.T) {
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Count/verify updates
var sink chan types.StateUpdate
go func() {
for range sink {
}
}()
machineKeys := []key.MachinePublic{
key.NewMachine().Public(),
key.NewMachine().Public(),
@ -291,6 +285,7 @@ func TestFailoverRoute(t *testing.T) {
name string
failingRoute types.Route
routes types.Routes
isConnected map[key.MachinePublic]bool
want []key.MachinePublic
wantErr bool
}{
@ -371,6 +366,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
@ -382,6 +378,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
types.Route{
Model: gorm.Model{
@ -392,8 +389,13 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1],
},
IsPrimary: false,
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
@ -411,6 +413,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: false,
Enabled: true,
},
routes: types.Routes{
types.Route{
@ -422,6 +425,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
types.Route{
Model: gorm.Model{
@ -432,6 +436,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1],
},
IsPrimary: false,
Enabled: true,
},
},
want: nil,
@ -448,6 +453,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
@ -459,6 +465,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: false,
Enabled: true,
},
types.Route{
Model: gorm.Model{
@ -469,6 +476,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1],
},
IsPrimary: true,
Enabled: true,
},
types.Route{
Model: gorm.Model{
@ -479,8 +487,14 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[2],
},
IsPrimary: false,
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[1]: true,
machineKeys[2]: true,
},
want: []key.MachinePublic{
machineKeys[1],
machineKeys[0],
@ -498,6 +512,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
@ -509,6 +524,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
// Offline
types.Route{
@ -520,8 +536,13 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[3],
},
IsPrimary: false,
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: true,
machineKeys[3]: false,
},
want: nil,
wantErr: false,
},
@ -536,6 +557,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
@ -547,6 +569,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
// Offline
types.Route{
@ -558,6 +581,7 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[3],
},
IsPrimary: false,
Enabled: true,
},
types.Route{
Model: gorm.Model{
@ -568,14 +592,61 @@ func TestFailoverRoute(t *testing.T) {
MachineKey: machineKeys[1],
},
IsPrimary: true,
Enabled: true,
},
},
isConnected: map[key.MachinePublic]bool{
machineKeys[0]: false,
machineKeys[1]: true,
machineKeys[3]: false,
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
},
wantErr: false,
},
{
name: "failover-primary-none-enabled",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
Enabled: true,
},
// not enabled
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: false,
Enabled: false,
},
},
want: nil,
wantErr: false,
},
}
for _, tt := range tests {
@ -583,13 +654,14 @@ func TestFailoverRoute(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err)
notif := notifier.NewNotifier()
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
notif,
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
@ -597,23 +669,15 @@ func TestFailoverRoute(t *testing.T) {
)
assert.NoError(t, err)
// Pretend that all the nodes are connected to control
for idx, key := range machineKeys {
// Pretend one node is offline
if idx == 3 {
continue
}
notif.AddNode(key, sink)
}
for _, route := range tt.routes {
if err := db.db.Save(&route).Error; err != nil {
if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err)
}
}
got, err := db.failoverRoute(&tt.failingRoute)
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
})
if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
@@ -627,3 +691,231 @@ func TestFailoverRoute(t *testing.T) {
})
}
}
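// What the table above exercises, distilled: failoverRoute only promotes a
// new primary among routes for the same prefix that are enabled and whose
// node is currently connected. A minimal sketch of that selection rule, not
// the actual implementation; pickFailover is a hypothetical name:
func pickFailover(
	failing *types.Route,
	routes types.Routes,
	isConnected map[key.MachinePublic]bool,
) *types.Route {
	for i, candidate := range routes {
		if candidate.ID == failing.ID || !candidate.Enabled {
			// Never fail over to the failing route itself, or to a
			// route the operator has not enabled.
			continue
		}
		if candidate.Prefix != failing.Prefix {
			continue
		}
		if isConnected[candidate.Node.MachineKey] {
			// The first enabled, online candidate becomes primary.
			return &routes[i]
		}
	}
	// No eligible candidate: the cases above then expect want == nil.
	return nil
}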
// func TestDisableRouteFailover(t *testing.T) {
// machineKeys := []key.MachinePublic{
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// key.NewMachine().Public(),
// }
// tests := []struct {
// name string
// nodes types.Nodes
// routeID uint64
// isConnected map[key.MachinePublic]bool
// wantMachineKey key.MachinePublic
// wantErr string
// }{
// {
// name: "single-route",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// Node: types.Node{
// MachineKey: machineKeys[0],
// },
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[0],
// },
// {
// name: "failover-simple",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "no-failover-offline",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: false,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "failover-to-online",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: true,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
// assert.NoError(t, err)
// // bootstrap db
// datab.DB.Transaction(func(tx *gorm.DB) error {
// for _, node := range tt.nodes {
// err := tx.Save(node).Error
// if err != nil {
// return err
// }
// _, err = SaveNodeRoutes(tx, node)
// if err != nil {
// return err
// }
// }
// return nil
// })
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
// return DisableRoute(tx, tt.routeID, tt.isConnected)
// })
// // if (err.Error() != "") != tt.wantErr {
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
// // return
// // }
// if len(got.ChangeNodes) != 1 {
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
// }
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }

View file

@@ -7,6 +7,7 @@ import (
"testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
)
@@ -45,9 +46,12 @@ func (s *Suite) ResetDB(c *check.C) {
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),

View file

@ -15,22 +15,25 @@ var (
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
)
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
return CreateUser(tx, name)
})
}
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user := types.User{}
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists
}
user.Name = name
if err := hsdb.db.Create(&user).Error; err != nil {
if err := tx.Create(&user).Error; err != nil {
log.Error().
Str("func", "CreateUser").
Err(err).
@@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
return &user, nil
}
func (hsdb *HSDatabase) DestroyUser(name string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DestroyUser(tx, name)
})
}
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it.
func (hsdb *HSDatabase) DestroyUser(name string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
user, err := hsdb.getUser(name)
func DestroyUser(tx *gorm.DB, name string) error {
user, err := GetUser(tx, name)
if err != nil {
return ErrUserNotFound
}
nodes, err := hsdb.listNodesByUser(name)
nodes, err := ListNodesByUser(tx, name)
if err != nil {
return err
}
@@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
return ErrUserStillHasNodes
}
keys, err := hsdb.listPreAuthKeys(name)
keys, err := ListPreAuthKeys(tx, name)
if err != nil {
return err
}
for _, key := range keys {
err = hsdb.destroyPreAuthKey(key)
err = DestroyPreAuthKey(tx, key)
if err != nil {
return err
}
}
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
if result := tx.Unscoped().Delete(&user); result.Error != nil {
return result.Error
}
return nil
}
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
return hsdb.Write(func(tx *gorm.DB) error {
return RenameUser(tx, oldName, newName)
})
}
// RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name.
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
func RenameUser(tx *gorm.DB, oldName, newName string) error {
var err error
oldUser, err := hsdb.getUser(oldName)
oldUser, err := GetUser(tx, oldName)
if err != nil {
return err
}
@@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
if err != nil {
return err
}
_, err = hsdb.getUser(newName)
_, err = GetUser(tx, newName)
if err == nil {
return ErrUserExists
}
@@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
oldUser.Name = newName
if result := hsdb.db.Save(&oldUser); result.Error != nil {
if result := tx.Save(&oldUser); result.Error != nil {
return result.Error
}
return nil
}
// GetUser fetches a user by name.
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getUser(name)
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUser(rx, name)
})
}
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{}
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
return &user, nil
}
// ListUsers gets all the existing users.
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listUsers()
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx)
})
}
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
// ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB) ([]types.User, error) {
users := []types.User{}
if err := hsdb.db.Find(&users).Error; err != nil {
if err := tx.Find(&users).Error; err != nil {
return nil, err
}
@@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
}
// ListNodesByUser gets all the nodes in a given user.
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.listNodesByUser(name)
}
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
err := util.CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user, err := hsdb.getUser(name)
user, err := GetUser(tx, name)
if err != nil {
return nil, err
}
nodes := types.Nodes{}
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
}
// AssignNodeToUser assigns a Node to a user.
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, node, username)
})
}
// AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
err := util.CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := hsdb.getUser(username)
user, err := GetUser(tx, username)
if err != nil {
return err
}
node.User = *user
if result := hsdb.db.Save(&node); result.Error != nil {
if result := tx.Save(&node); result.Error != nil {
return result.Error
}
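// The shape repeated throughout this file: each exported HSDatabase method is
// now a thin wrapper that opens a transaction with Read/Write and delegates to
// a package-level function taking *gorm.DB, so the same logic also composes
// inside a larger caller-owned transaction. A sketch with a hypothetical
// Widget entity that is not part of headscale:
type Widget struct{ gorm.Model }

func (hsdb *HSDatabase) CountWidgets() (int64, error) {
	return Read(hsdb.DB, func(rx *gorm.DB) (int64, error) {
		return CountWidgets(rx)
	})
}

func CountWidgets(rx *gorm.DB) (int64, error) {
	var count int64
	// A read-only query; composes with GetUser, ListUsers, etc. in one tx.
	err := rx.Model(&Widget{}).Count(&count).Error
	return count, err
}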

View file

@@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
@@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes)
@@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
db.db.Save(&node)
db.DB.Save(&node)
c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, newUser.Name)

View file

@@ -84,6 +84,8 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
RegionID: d.cfg.ServerRegionID,
HostName: host,
DERPPort: port,
IPv4: d.cfg.IPv4,
IPv6: d.cfg.IPv6,
},
},
}
@@ -99,6 +101,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
localDERPregion.Nodes[0].STUNPort = portSTUN
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
return localDERPregion, nil
}
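// With the two new config fields, the embedded region can pin explicit
// addresses so clients do not have to resolve the hostname first. A sketch of
// the resulting region; every value below is an example, not a default:
var exampleRegion = tailcfg.DERPRegion{
	RegionID:   999,
	RegionCode: "headscale",
	RegionName: "Headscale Embedded DERP",
	Nodes: []*tailcfg.DERPNode{{
		Name:     "999",
		RegionID: 999,
		HostName: "derp.example.com",
		DERPPort: 443,
		IPv4:     "203.0.113.10", // from derp.server.ipv4
		IPv6:     "2001:db8::10", // from derp.server.ipv6
	}},
}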
@@ -208,6 +211,7 @@ func DERPProbeHandler(
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but it is not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
// The coordination server is included automatically, since the local DERP server uses the same DNS name in d.serverURL.
func DERPBootstrapDNSHandler(
derpMap *tailcfg.DERPMap,
) func(http.ResponseWriter, *http.Request) {

View file

@@ -8,11 +8,13 @@ import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
if err != nil {
return nil, err
return err
}
err = api.h.db.ExpirePreAuthKey(preAuthKey)
return db.ExpirePreAuthKey(tx, preAuthKey)
})
if err != nil {
return nil, err
}
@@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
node, err := api.h.db.RegisterNodeFromAuthCallback(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
return db.RegisterNodeFromAuthCallback(
tx,
api.h.registrationCache,
mkey,
request.GetUser(),
nil,
util.RegisterMethodCLI,
api.h.cfg.IPPrefixes,
)
})
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RegisterNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags(
ctx context.Context,
request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
return nil, err
}
}
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
if err != nil {
return nil, err
}
for _, tag := range request.GetTags() {
err := validateTag(tag)
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
err = api.h.db.SetTags(node, request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.Internal, err.Error())
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.SetTags",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
@@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode(
err = api.h.db.DeleteNode(
node,
api.h.nodeNotifier.ConnectedMap(),
)
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
return &v1.DeleteNodeResponse{}, nil
}
@@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode(
ctx context.Context,
request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
now := time.Now()
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry(
tx,
request.GetNodeId(),
now,
)
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return nil, err
}
now := time.Now()
selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
api.h.db.NodeSetExpiry(
node,
now,
)
stateUpdate := types.StateUpdateExpire(node.ID, now)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
Str("node", node.Hostname).
@@ -306,19 +365,32 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context,
request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
if err != nil {
return nil, err
}
err = api.h.db.RenameNode(
node,
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode(
tx,
request.GetNodeId(),
request.GetNewName(),
)
if err != nil {
return nil, err
}
return db.GetNodeByID(tx, request.GetNodeId())
})
if err != nil {
return nil, err
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from api.RenameNode",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace().
Str("node", node.Hostname).
Str("new_name", request.GetNewName()).
@@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context,
request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap()
if request.GetUser() != "" {
nodes, err := api.h.db.ListNodesByUser(request.GetUser())
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, request.GetUser())
})
if err != nil {
return nil, err
}
@@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
resp.Online = isConnected[node.MachineKey]
response[index] = resp
}
@@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
resp.Online = isConnected[node.MachineKey]
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
&node,
node,
)
resp.InvalidTags = invalidTags
resp.ValidTags = validTags
@@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes(
ctx context.Context,
request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) {
routes, err := api.h.db.GetRoutes()
routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
return db.GetRoutes(rx)
})
if err != nil {
return nil, err
}
@@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute(
ctx context.Context,
request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) {
err := api.h.db.EnableRoute(request.GetRouteId())
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnableRoute(tx, request.GetRouteId())
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
api.h.nodeNotifier.NotifyAll(
ctx, *update)
}
return &v1.EnableRouteResponse{}, nil
}
@@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute(
ctx context.Context,
request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) {
err := api.h.db.DisableRoute(request.GetRouteId())
isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, *update)
}
return &v1.DisableRouteResponse{}, nil
}
@@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context,
request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) {
err := api.h.db.DeleteRoute(request.GetRouteId())
isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
})
if err != nil {
return nil, err
}
if update != nil && update.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
}
return &v1.DeleteRouteResponse{}, nil
}

View file

@@ -272,13 +272,26 @@ func (m *Mapper) LiteMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
node,
nodeMapToList(m.peers),
)
if err != nil {
return nil, err
}
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.SSHPolicy = sshPolicy
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
@@ -380,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse(
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
patches := append(patches, p)
m.patches[uint64(change.NodeID)] = patches
m.patches[uint64(change.NodeID)] = append(patches, p)
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
@@ -458,6 +469,8 @@ func (m *Mapper) marshalMapResponse(
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case isSelfUpdate(messages...):
responseType = "self"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
responseType = "lite"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
@@ -656,3 +669,13 @@ func appendPeerChanges(
return nil
}
func isSelfUpdate(messages ...string) bool {
for _, message := range messages {
if strings.Contains(message, types.SelfUpdateIdentifier) {
return true
}
}
return false
}
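// How the marker flows end to end: the poller tags a self-targeted
// LiteMapResponse with types.SelfUpdateIdentifier, and marshalMapResponse
// above then classifies the response as "self" instead of "lite".
// The call site in poll.go looks like:
//
//	data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy,
//		types.SelfUpdateIdentifier)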

View file

@@ -72,7 +72,7 @@ func tailNode(
}
var derp string
if node.Hostinfo.NetInfo != nil {
if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
} else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.

View file

@@ -1,6 +1,7 @@
package notifier
import (
"context"
"fmt"
"strings"
"sync"
@@ -14,24 +15,28 @@ import (
type Notifier struct {
l sync.RWMutex
nodes map[string]chan<- types.StateUpdate
connected map[key.MachinePublic]bool
}
func NewNotifier() *Notifier {
return &Notifier{}
return &Notifier{
nodes: make(map[string]chan<- types.StateUpdate),
connected: make(map[key.MachinePublic]bool),
}
}
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to add node")
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
n.nodes = make(map[string]chan<- types.StateUpdate)
}
n.nodes[machineKey.String()] = c
n.connected[machineKey] = true
log.Trace().
Str("machine_key", machineKey.ShortString()).
@@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to remove node")
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
if len(n.nodes) == 0 {
return
}
delete(n.nodes, machineKey.String())
n.connected[machineKey] = false
log.Trace().
Str("machine_key", machineKey.ShortString()).
@@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
n.l.RLock()
defer n.l.RUnlock()
if _, ok := n.nodes[machineKey.String()]; ok {
return true
}
return false
return n.connected[machineKey]
}
func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update)
// TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
return n.connected
}
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
}
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignore ...string,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifing")
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
@@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
continue
}
log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update")
c <- update
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) {
func (n *Notifier) NotifyByMachineKey(
ctx context.Context,
update types.StateUpdate,
mKey key.MachinePublic,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifing")
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if c, ok := n.nodes[mKey.String()]; ok {
c <- update
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
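// A minimal sketch of how callers are expected to drive the new signatures:
// wrap a short-lived context with types.NotifyCtx (a 3-second timeout plus
// the "origin" and "hostname" values read by the log lines above) so one
// stuck receiver can no longer block the sender forever. notifyPeersExample
// is a hypothetical helper, not part of this package:
func notifyPeersExample(n *Notifier, update types.StateUpdate, self key.MachinePublic) {
	ctx := types.NotifyCtx(context.Background(), "example-origin", "example-host")
	// Deliver to everyone except the node that caused the change.
	n.NotifyWithIgnore(ctx, update, self.String())
}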

View file

@@ -21,6 +21,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key"
)
@@ -569,7 +570,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("node", node.Hostname).
Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node, expiry)
err := h.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "Failed to refresh node")
http.Error(
@@ -623,6 +624,12 @@ func (h *Headscale) validateNodeForOIDCCallback(
util.LogErr(err, "Failed to write response")
}
stateUpdate := types.StateUpdateExpire(node.ID, expiry)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return nil, true, nil
}
@@ -712,14 +719,22 @@ func (h *Headscale) registerNodeForOIDCCallback(
machineKey *key.MachinePublic,
expiry time.Time,
) error {
if _, err := h.db.RegisterNodeFromAuthCallback(
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache,
*machineKey,
user.Name,
&expiry,
util.RegisterMethodOIDC,
h.cfg.IPPrefixes,
); err != nil {
return err
}
return nil
}); err != nil {
util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)

View file

@@ -250,6 +250,21 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
if node.IPAddresses.InIPSet(expanded) {
dests = append(dests, dest)
}
// If the node exposes routes, ensure they are not removed
// when the filters are reduced.
if node.Hostinfo != nil {
// TODO(kradalby): Evaluate if we should only keep
// the routes if the route is enabled. This will
// require database access in this part of the code.
if len(node.Hostinfo.RoutableIPs) > 0 {
for _, routableIP := range node.Hostinfo.RoutableIPs {
if expanded.ContainsPrefix(routableIP) {
dests = append(dests, dest)
}
}
}
}
}
if len(dests) > 0 {
@@ -890,8 +905,14 @@ func (pol *ACLPolicy) TagsOfNode(
validTags := make([]string, 0)
invalidTags := make([]string, 0)
// TODO(kradalby): Why is this sometimes nil? coming from tailNode?
if node == nil {
return validTags, invalidTags
}
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
if node.Hostinfo != nil {
for _, tag := range node.Hostinfo.RequestTags {
owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) {
@@ -917,6 +938,7 @@ func (pol *ACLPolicy) TagsOfNode(
for tag := range validTagMap {
validTags = append(validTags, tag)
}
}
return validTags, invalidTags
}
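// The ReduceFilterRules change above addresses #1604: when filters are
// trimmed to destinations the node itself owns, a subnet router would lose
// the rules for the prefixes it announces. The added block keeps a
// destination whenever one of the node's announced RoutableIPs is contained
// in the expanded destination set. A standalone sketch of that check
// (routeKeepsDest is a hypothetical name):
func routeKeepsDest(node *types.Node, expanded *netipx.IPSet) bool {
	if node.Hostinfo == nil {
		return false
	}
	for _, routableIP := range node.Hostinfo.RoutableIPs {
		if expanded.ContainsPrefix(routableIP) {
			return true
		}
	}
	return false
}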

View file

@@ -1901,6 +1901,81 @@ func TestReduceFilterRules(t *testing.T) {
},
want: []tailcfg.FilterRule{},
},
{
name: "1604-subnet-routers-are-preserved",
pol: ACLPolicy{
Groups: Groups{
"group:admins": {"user1"},
},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"group:admins:*"},
},
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"10.33.0.0/16:*"},
},
},
},
node: &types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0::1"),
},
User: types.User{Name: "user1"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.2"),
netip.MustParseAddr("fd7a:115c:a1e0::2"),
},
User: types.User{Name: "user1"},
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
},
},
},
}
for _, tt := range tests {

View file

@@ -4,12 +4,15 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
@@ -125,6 +128,18 @@ func (h *Headscale) handlePoll(
return
}
if h.ACLPolicy != nil {
// update routes with peer information
update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil {
logErr(err, "Error running auto approved routes")
}
if update != nil {
sendUpdate = true
}
}
}
// Services is mostly useful for discovery and not critical,
@@ -138,41 +153,68 @@ func (h *Headscale) handlePoll(
}
if sendUpdate {
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
// Send an update to all peers to propagate the newly
// available routes.
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from handlePoll -> update -> new hostinfo",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
// Send an update to the node itself to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
selfUpdate := types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: types.Nodes{node},
}
if selfUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname)
h.nodeNotifier.NotifyByMachineKey(
ctx,
selfUpdate,
node.MachineKey)
}
return
}
}
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
// TODO(kradalby): Figure out why patch changes do
// not show up in output from `tailscale debug netmap`.
// stateUpdate := types.StateUpdate{
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{&change},
// }
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{&change},
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
@@ -228,7 +270,7 @@ func (h *Headscale) handlePoll(
}
}
if err := h.db.NodeSave(node); err != nil {
if err := h.db.DB.Save(node).Error; err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
@@ -265,7 +307,10 @@ func (h *Headscale) handlePoll(
// update ACLRules with peer information (to update server tags if necessary)
if h.ACLPolicy != nil {
// update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
// This state update is ignored as it will be sent
// as part of the whole node
// TODO(kradalby): figure out if that is actually correct
_, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node)
if err != nil {
logErr(err, "Error running auto approved routes")
}
@@ -301,11 +346,17 @@ func (h *Headscale) handlePoll(
Message: "called from handlePoll -> new node added",
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(
ctx,
stateUpdate,
node.MachineKey.String())
}
if len(node.Routes) > 0 {
go h.pollFailoverRoutes(logErr, "new node", node)
}
// Set up the client stream
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
@@ -323,15 +374,9 @@ func (h *Headscale) handlePoll(
keepAliveTicker := time.NewTicker(keepAliveInterval)
ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname)
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))
defer cancel()
if len(node.Routes) > 0 {
go h.db.EnsureFailoverRouteIsAvailable(node)
}
for {
logInfo("Waiting for update on stream channel")
select {
@@ -370,6 +415,17 @@ func (h *Headscale) handlePoll(
var data []byte
var err error
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a side channel
// which contains data needed to generate a map response.
node, err = h.db.GetNodeByMachineKey(node.MachineKey)
if err != nil {
logErr(err, "Could not get machine from db")
return
}
startMapResp := time.Now()
switch update.Type {
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
@@ -378,6 +434,7 @@ func (h *Headscale) handlePoll(
case types.StatePeerChanged:
logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
isConnectedMap := h.nodeNotifier.ConnectedMap()
for _, node := range update.ChangeNodes {
// If a node is not reported to be online, it might be
// because the value is outdated; check with the notifier.
@@ -385,7 +442,7 @@ func (h *Headscale) handlePoll(
// this might be because it has announced itself, but not
// reached the stage to actually create the notifier channel.
if node.IsOnline != nil && !*node.IsOnline {
isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
isOnline := isConnectedMap[node.MachineKey]
node.IsOnline = &isOnline
}
}
@@ -401,7 +458,7 @@ func (h *Headscale) handlePoll(
if len(update.ChangeNodes) == 1 {
logInfo("Sending SelfUpdate MapResponse")
node = update.ChangeNodes[0]
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy)
data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier)
} else {
logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.")
}
@@ -416,8 +473,11 @@ func (h *Headscale) handlePoll(
return
}
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response")
// Only send an update if there is a change
if data != nil {
startWrite := time.Now()
_, err = writer.Write(data)
if err != nil {
logErr(err, "Could not write the map response")
@@ -435,6 +495,7 @@ func (h *Headscale) handlePoll(
return
}
log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node")
log.Info().
Caller().
@@ -454,7 +515,7 @@ func (h *Headscale) handlePoll(
go h.updateNodeOnlineStatus(false, node)
// Failover the node's routes if any.
go h.db.FailoverNodeRoutesWithNotify(node)
go h.pollFailoverRoutes(logErr, "node closing connection", node)
// The connection has been closed, so we can stop polling.
return
@@ -467,6 +528,22 @@ func (h *Headscale) handlePoll(
}
}
func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) {
update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node)
})
if err != nil {
logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
return
}
if update != nil && !update.Empty() && update.Valid() {
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String())
}
}
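// pollFailoverRoutes is now the single entry point for both failover
// triggers in this handler, replacing the EnsureFailoverRouteIsAvailable and
// FailoverNodeRoutesWithNotify call sites above:
//
//	go h.pollFailoverRoutes(logErr, "new node", node)                // on registration
//	go h.pollFailoverRoutes(logErr, "node closing connection", node) // on disconnect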
// updateNodeOnlineStatus records the last seen status of a node and notifies
// its peers about the change in the node's online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@@ -486,10 +563,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
},
}
if statusUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String())
}
err := h.db.UpdateLastSeen(node)
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
return db.UpdateLastSeen(tx, node.ID, *node.LastSeen)
})
if err != nil {
log.Error().Err(err).Msg("Cannot update node LastSeen")

View file

@@ -13,7 +13,7 @@ import (
)
const (
MinimumCapVersion tailcfg.CapabilityVersion = 56
MinimumCapVersion tailcfg.CapabilityVersion = 58
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol

View file

@@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) {
}
cfg := types.Config{
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
DBtype: "sqlite3",
DBpath: tmpDir + "/headscale_test.db",
Database: types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},

View file

@@ -1,15 +1,23 @@
package types
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
"time"
"tailscale.com/tailcfg"
)
const (
SelfUpdateIdentifier = "self-update"
DatabasePostgres = "postgres"
DatabaseSqlite = "sqlite3"
)
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
type IPPrefix netip.Prefix
@@ -150,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
}
case StateSelfUpdate:
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
panic(
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
)
}
case StateDERPUpdated:
if su.DERPMap == nil {
@@ -160,3 +170,37 @@ func (su *StateUpdate) Valid() bool {
return true
}
// Empty reports whether the StateUpdate contains no changes and can be skipped.
func (su *StateUpdate) Empty() bool {
switch su.Type {
case StatePeerChanged:
return len(su.ChangeNodes) == 0
case StatePeerChangedPatch:
return len(su.ChangePatches) == 0
case StatePeerRemoved:
return len(su.Removed) == 0
}
return false
}
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(nodeID),
KeyExpiry: &expiry,
},
},
}
}
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(
context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
3*time.Second,
)
return ctx2
}
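// A sketch of the intended pairing of the two helpers above, mirroring the
// call sites in grpcv1.go and oidc.go; the send parameter stands in for
// Notifier.NotifyWithIgnore so this package stays free of an import cycle,
// and exampleExpireAndNotify is a hypothetical name:
func exampleExpireAndNotify(
	send func(context.Context, StateUpdate, ...string),
	nodeID uint64,
	machineKey string,
	expiry time.Time,
) {
	su := StateUpdateExpire(nodeID, expiry)
	if su.Valid() {
		// Bounded context: the send is abandoned after 3 seconds.
		ctx := NotifyCtx(context.Background(), "example-expire", "example-host")
		send(ctx, su, machineKey)
	}
}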

View file

@@ -11,7 +11,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -20,6 +19,8 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"github.com/juanfont/headscale/hscontrol/util"
)
const (
@@ -46,16 +47,9 @@ type Config struct {
Log LogConfig
DisableUpdateCheck bool
DERP DERPConfig
Database DatabaseConfig
DBtype string
DBpath string
DBhost string
DBport int
DBname string
DBuser string
DBpass string
DBssl string
DERP DERPConfig
TLS TLSConfig
@@ -77,6 +71,31 @@ type Config struct {
ACL ACLConfig
}
type SqliteConfig struct {
Path string
}
type PostgresConfig struct {
Host string
Port int
Name string
User string
Pass string
Ssl string
MaxOpenConnections int
MaxIdleConnections int
ConnMaxIdleTimeSecs int
}
type DatabaseConfig struct {
// Type sets the database type, either "sqlite3" or "postgres"
Type string
Debug bool
Sqlite SqliteConfig
Postgres PostgresConfig
}
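// The two shapes the new block supports, replacing the flat DB* fields the
// Config struct used to carry; every value below is an example, not a
// shipped default:
var exampleSqlite = DatabaseConfig{
	Type:   "sqlite3",
	Sqlite: SqliteConfig{Path: "/var/lib/headscale/db.sqlite"},
}

var examplePostgres = DatabaseConfig{
	Type: "postgres",
	Postgres: PostgresConfig{
		Host:                "localhost",
		Port:                5432,
		Name:                "headscale",
		User:                "headscale",
		Pass:                "secret",
		Ssl:                 "disable",
		MaxOpenConnections:  10,
		MaxIdleConnections:  10,
		ConnMaxIdleTimeSecs: 3600,
	},
}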
type TLSConfig struct {
CertPath string
KeyPath string
@@ -112,6 +131,7 @@ type OIDCConfig struct {
type DERPConfig struct {
ServerEnabled bool
AutomaticallyAddEmbeddedDerpRegion bool
ServerRegionID int
ServerRegionCode string
ServerRegionName string
@@ -121,6 +141,8 @@ type DERPConfig struct {
Paths []string
AutoUpdate bool
UpdateFrequency time.Duration
IPv4 string
IPv6 string
}
type LogTailConfig struct {
@@ -162,6 +184,19 @@ func LoadConfig(path string, isFile bool) error {
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
viper.RegisterAlias("db_type", "database.type")
// SQLite aliases
viper.RegisterAlias("db_path", "database.sqlite.path")
// Postgres aliases
viper.RegisterAlias("db_host", "database.postgres.host")
viper.RegisterAlias("db_port", "database.postgres.port")
viper.RegisterAlias("db_name", "database.postgres.name")
viper.RegisterAlias("db_user", "database.postgres.user")
viper.RegisterAlias("db_pass", "database.postgres.pass")
viper.RegisterAlias("db_ssl", "database.postgres.ssl")
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
@@ -173,6 +208,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("derp.server.enabled", false)
viper.SetDefault("derp.server.stun.enabled", true)
viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true)
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
viper.SetDefault("unix_socket_permission", "0o770")
@@ -184,6 +220,10 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("cli.insecure", false)
viper.SetDefault("db_ssl", false)
viper.SetDefault("database.postgres.ssl", false)
viper.SetDefault("database.postgres.max_open_conns", 10)
viper.SetDefault("database.postgres.max_idle_conns", 10)
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true)
@@ -262,7 +302,7 @@ func LoadConfig(path string, isFile bool) error {
}
if errorText != "" {
//nolint
// nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
@@ -294,8 +334,14 @@ func GetDERPConfig() DERPConfig {
serverRegionCode := viper.GetString("derp.server.region_code")
serverRegionName := viper.GetString("derp.server.region_name")
stunAddr := viper.GetString("derp.server.stun_listen_addr")
privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path"))
privateKeyPath := util.AbsolutePathFromConfigPath(
viper.GetString("derp.server.private_key_path"),
)
ipv4 := viper.GetString("derp.server.ipv4")
ipv6 := viper.GetString("derp.server.ipv6")
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
"derp.server.automatically_add_embedded_derp_region",
)
if serverEnabled && stunAddr == "" {
log.Fatal().
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
@@ -318,6 +364,11 @@ func GetDERPConfig() DERPConfig {
paths := viper.GetStringSlice("derp.paths")
if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 {
log.Fatal().
Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths")
}
autoUpdate := viper.GetBool("derp.auto_update_enabled")
updateFrequency := viper.GetDuration("derp.update_frequency")
@@ -332,6 +383,9 @@ func GetDERPConfig() DERPConfig {
Paths: paths,
AutoUpdate: autoUpdate,
UpdateFrequency: updateFrequency,
IPv4: ipv4,
IPv6: ipv6,
AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion,
}
}
@@ -379,6 +433,45 @@ func GetLogConfig() LogConfig {
}
}
func GetDatabaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
switch type_ {
case DatabaseSqlite, DatabasePostgres:
break
case "sqlite":
type_ = "sqlite3"
default:
log.Fatal().
Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
}
return DatabaseConfig{
Type: type_,
Debug: debug,
Sqlite: SqliteConfig{
Path: util.AbsolutePathFromConfigPath(
viper.GetString("database.sqlite.path"),
),
},
Postgres: PostgresConfig{
Host: viper.GetString("database.postgres.host"),
Port: viper.GetInt("database.postgres.port"),
Name: viper.GetString("database.postgres.name"),
User: viper.GetString("database.postgres.user"),
Pass: viper.GetString("database.postgres.pass"),
Ssl: viper.GetString("database.postgres.ssl"),
MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"),
MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"),
ConnMaxIdleTimeSecs: viper.GetInt(
"database.postgres.conn_max_idle_time_secs",
),
},
}
}
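// Thanks to the aliases registered in LoadConfig, a legacy flat config keeps
// working. A sketch (exampleLegacyConfig is hypothetical, and assumes viper
// has already been initialised by LoadConfig):
func exampleLegacyConfig() DatabaseConfig {
	viper.Set("db_type", "sqlite3")
	viper.Set("db_path", "/tmp/headscale_test.db")

	// The aliases route the old keys to database.type and
	// database.sqlite.path, so the new accessor still sees them.
	return GetDatabaseConfig()
}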
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
@@ -580,7 +673,7 @@ func GetHeadscaleConfig() (*Config, error) {
if err != nil {
return nil, err
}
oidcClientSecret = string(secretBytes)
oidcClientSecret = strings.TrimSpace(string(secretBytes))
}
return &Config{
@@ -607,14 +700,7 @@ func GetHeadscaleConfig() (*Config, error) {
"node_update_check_interval",
),
DBtype: viper.GetString("db_type"),
DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
DBssl: viper.GetString("db_ssl"),
Database: GetDatabaseConfig(),
TLS: GetTLSConfig(),

View file

@@ -383,7 +383,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
// inform peers about smaller changes to the node.
// When a field is added to this function, remember to also add it to:
// - node.ApplyPeerChange
// - logTracePeerChange in poll.go
// - logTracePeerChange in poll.go.
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
ret := tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),

View file

@@ -6,6 +6,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@@ -366,3 +367,110 @@ func TestPeerChangeFromMapRequest(t *testing.T) {
})
}
}
func TestApplyPeerChange(t *testing.T) {
tests := []struct {
name string
nodeBefore Node
change *tailcfg.PeerChange
want Node
}{
{
name: "hostinfo-and-netinfo-not-exists",
nodeBefore: Node{},
change: &tailcfg.PeerChange{
DERPRegion: 1,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 1,
},
},
},
},
{
name: "hostinfo-netinfo-not-exists",
nodeBefore: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
},
},
change: &tailcfg.PeerChange{
DERPRegion: 3,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 3,
},
},
},
},
{
name: "hostinfo-netinfo-exists-derp-set",
nodeBefore: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 999,
},
},
},
change: &tailcfg.PeerChange{
DERPRegion: 2,
},
want: Node{
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 2,
},
},
},
},
{
name: "endpoints-not-set",
nodeBefore: Node{},
change: &tailcfg.PeerChange{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
want: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
},
{
name: "endpoints-set",
nodeBefore: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("6.6.6.6:66"),
},
},
change: &tailcfg.PeerChange{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
want: Node{
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("8.8.8.8:88"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.nodeBefore.ApplyPeerChange(tt.change)
if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" {
t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@@ -2,7 +2,6 @@ package types
import (
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
@@ -25,9 +24,10 @@ func (n *User) TailscaleUser() *tailcfg.User {
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "",
Logins: []tailcfg.LoginID{},
Created: time.Time{},
Created: n.CreatedAt,
}
return &user
@@ -38,6 +38,7 @@ func (n *User) TailscaleLogin() *tailcfg.Login {
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
// TODO(kradalby): See if we can fill in Gravatar here
ProfilePicURL: "",
}

View file

@@ -15,6 +15,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool {
return x == y
})
var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String()
})
@@ -28,5 +32,5 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
})
var Comparers []cmp.Option = []cmp.Option{
IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer,
IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer,
}

View file

@@ -265,6 +265,8 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
@@ -325,6 +327,8 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})

View file

@@ -53,6 +53,8 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
@@ -90,6 +92,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})

View file

@@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAll[4].GetGivenName(), "node-5")
for idx := 0; idx < 3; idx++ {
_, err := headscale.Execute(
res, err := headscale.Execute(
[]string{
"headscale",
"nodes",
@@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) {
},
)
assert.Nil(t, err)
assert.Contains(t, res, "Node renamed")
}
var listAllAfterRename []v1.Node

View file

@@ -33,20 +33,23 @@ func TestDERPServerScenario(t *testing.T) {
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
"user1": 10,
// "user1": len(MustTestVersions),
}
headscaleConfig := map[string]string{}
headscaleConfig["HEADSCALE_DERP_URLS"] = ""
headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true"
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999"
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale"
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP"
headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478"
headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key"
headscaleConfig := map[string]string{
"HEADSCALE_DERP_URLS": "",
"HEADSCALE_DERP_SERVER_ENABLED": "true",
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
// Envknob for enabling DERP debug logs
headscaleConfig["DERP_DEBUG_LOGS"] = "true"
headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true"
"DERP_DEBUG_LOGS": "true",
"DERP_PROBER_DEBUG_LOGS": "true",
}
err = scenario.CreateHeadscaleEnv(
spec,
@@ -67,6 +70,8 @@ func TestDERPServerScenario(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)

View file

@@ -26,12 +26,34 @@ func TestPingAllByIP(t *testing.T) {
assertNoErr(t, err)
defer scenario.Shutdown()
// TODO(kradalby): it does not look like the multi-user spec works; only the
// second user gets created? Maybe only when there are many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip"))
headscaleConfig := map[string]string{
"HEADSCALE_DERP_URLS": "",
"HEADSCALE_DERP_SERVER_ENABLED": "true",
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
// Envknob for enabling DERP debug logs
"DERP_DEBUG_LOGS": "true",
"DERP_PROBER_DEBUG_LOGS": "true",
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("pingallbyip"),
hsic.WithConfigEnv(headscaleConfig),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
@ -43,6 +65,46 @@ func TestPingAllByIP(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
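The lo.Map call above comes from github.com/samber/lo; a dependency-free equivalent is a plain loop (a sketch):
// Sketch: plain-Go equivalent of the lo.Map conversion above.
allAddrs := make([]string, 0, len(allIps))
for _, addr := range allIps {
	allAddrs = append(allAddrs, addr.String())
}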
func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("pingallbyippubderp"),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
@ -73,6 +135,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr)
for _, client := range allClients {
ips, err := client.IPs()
@ -112,6 +176,8 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allClients, err = scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
@ -263,6 +329,8 @@ func TestPingAllByHostname(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
@ -320,9 +388,13 @@ func TestTaildrop(t *testing.T) {
if err != nil {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
}
}
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
curlCommand := []string{
"curl",
"--unix-socket",
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error {
result, _, err := client.Execute(curlCommand)
if err != nil {
@ -339,13 +411,23 @@ func TestTaildrop(t *testing.T) {
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
len(fts),
len(allClients)-1,
ftStr,
)
}
return err
})
if err != nil {
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
}
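The retry helper used above is defined elsewhere in the integration suite; a minimal sketch matching the (attempts, delay, fn) call shape seen here:
// Sketch: a retry helper with the signature used above.
func retry(attempts int, delay time.Duration, fn func() error) (err error) {
	for i := 0; i < attempts; i++ {
		if err = fn(); err == nil {
			return nil
		}
		time.Sleep(delay)
	}
	return err
}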
@ -457,6 +539,8 @@ func TestResolveMagicDNS(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
// Poor man's cache
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
@ -525,6 +609,8 @@ func TestExpireNode(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
@ -546,7 +632,7 @@ func TestExpireNode(t *testing.T) {
// TODO(kradalby): This is Headscale specific and would not play nicely
// with other implementations of the ControlServer interface
result, err := headscale.Execute([]string{
"headscale", "nodes", "expire", "--identifier", "0", "--output", "json",
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
})
assertNoErr(t, err)
@ -577,16 +663,38 @@ func TestExpireNode(t *testing.T) {
assertNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
t.Logf(
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
if peerStatus.KeyExpiry != nil {
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
assert.Truef(
t,
peerStatus.KeyExpiry.Before(now),
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
}
assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
assert.Truef(
t,
peerStatus.Expired,
"node %q should be expired, expired is %v",
peerStatus.HostName,
peerStatus.Expired,
)
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
if !strings.Contains(stderr, "node key has expired") {
t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname())
t.Errorf(
"expected to be unable to ping expired host %q from %q",
node.GetName(),
client.Hostname(),
)
}
} else {
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
@ -598,7 +706,7 @@ func TestExpireNode(t *testing.T) {
// NeedsLogin means that the node has understood that it is no longer
// valid.
assert.Equal(t, "NeedsLogin", status.BackendState)
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
}
}
}
@ -627,6 +735,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
@ -691,7 +801,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
assert.Truef(
t,
lastSeen.After(lastSeenThreshold),
"lastSeen (%v) was not %s after the threshold (%v)",
"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
node.GetName(),
lastSeen,
keepAliveInterval,
lastSeenThreshold,

View file

@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string {
return map[string]string{
"HEADSCALE_LOG_LEVEL": "trace",
"HEADSCALE_ACL_POLICY_PATH": "",
"HEADSCALE_DB_TYPE": "sqlite3",
"HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3",
"HEADSCALE_DATABASE_TYPE": "sqlite",
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s",
"HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10",

View file

@ -9,10 +9,15 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"tailscale.com/types/ipproto"
"tailscale.com/wgengine/filter"
)
// This test is both testing the routes command and the propagation of
@ -83,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, routes, 3)
for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes have been sent to the client,
@ -130,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary())
}
time.Sleep(5 * time.Second)
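The assert.Equal rewrites in this file put the expected value first, matching testify's (t, expected, actual) convention; for booleans the dedicated helpers are terser still (a sketch mirroring the loop body above):
// Sketch: equivalent boolean assertions with testify's True/False helpers.
assert.True(t, route.GetAdvertised())
assert.True(t, route.GetEnabled())
assert.True(t, route.GetIsPrimary())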
@ -186,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) {
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var disablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
@ -204,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) {
assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
} else {
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary())
}
}
time.Sleep(5 * time.Second)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
@ -289,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients {
for _, client := range allClients[:2] {
status, err := client.Status()
assertNoErr(t, err)
@ -301,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
} else {
t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
}
}
@ -323,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, routes, 2)
t.Logf("initial routes %#v", routes)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
@ -639,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2)
t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
// Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
@ -778,3 +789,413 @@ func TestHASubnetRouterFailover(t *testing.T) {
)
}
}
func TestEnableDisableAutoApprovedRoute(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
expectedRoutes := "172.0.0.0/24"
user := "enable-disable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 1,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
&policy.ACLPolicy{
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
TagOwners: map[string][]string{
"tag:approve": {user},
},
AutoApprovers: policy.AutoApprovers{
Routes: map[string][]string{
expectedRoutes: {"tag:approve"},
},
},
},
))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
subRouter1 := allClients[0]
// Initially advertise route
command := []string{
"tailscale",
"set",
"--advertise-routes=" + expectedRoutes,
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
time.Sleep(10 * time.Second)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 1)
// All routes should be auto approved and enabled
assert.Equal(t, true, routes[0].GetAdvertised())
assert.Equal(t, true, routes[0].GetEnabled())
assert.Equal(t, true, routes[0].GetIsPrimary())
// Stop advertising route
command = []string{
"tailscale",
"set",
"--advertise-routes=",
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to remove advertised route: %s", err)
time.Sleep(10 * time.Second)
var notAdvertisedRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&notAdvertisedRoutes,
)
assertNoErr(t, err)
assert.Len(t, notAdvertisedRoutes, 1)
// Route is no longer advertised
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary())
// Advertise route again
command = []string{
"tailscale",
"set",
"--advertise-routes=" + expectedRoutes,
}
_, _, err = subRouter1.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
time.Sleep(10 * time.Second)
var reAdvertisedRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&reAdvertisedRoutes,
)
assertNoErr(t, err)
assert.Len(t, reAdvertisedRoutes, 1)
// All routes should be auto approved and enabled
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised())
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled())
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary())
}
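The fixed 10-second sleeps above give headscale time to process route changes; a polling variant (a sketch reusing the executeAndUnmarshal helper from these tests) could reduce flakiness:
// Sketch: poll for the expected route state instead of sleeping.
deadline := time.Now().Add(30 * time.Second)
for time.Now().Before(deadline) {
	var rs []*v1.Route
	err := executeAndUnmarshal(headscale,
		[]string{"headscale", "routes", "list", "--output", "json"}, &rs)
	if err == nil && len(rs) == 1 && rs[0].GetEnabled() {
		break
	}
	time.Sleep(time.Second)
}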
// TestSubnetRouteACL verifies that Subnet routes are distributed
// as expected when ACLs are activated.
// It reproduces the scenario from
// https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "subnet-route-acl"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 2,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
&policy.ACLPolicy{
Groups: policy.Groups{
"group:admins": {user},
},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"group:admins:*"},
},
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"10.33.0.0/16:*"},
},
// {
// Action: "accept",
// Sources: []string{"group:admins"},
// Destinations: []string{"0.0.0.0/0:*"},
// },
},
},
))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.33.0.0/16",
}
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}
statusJ, err := allClients[j].Status()
if err != nil {
return false
}
return statusI.Self.ID < statusJ.Self.ID
})
subRouter1 := allClients[0]
client := allClients[1]
// advertise the subnet route on node 1 (ID 1),
// the only router in this scenario
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
command := []string{
"tailscale",
"set",
"--advertise-routes=" + route,
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 1)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes have been sent to the client,
// they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Nil(t, peerStatus.PrimaryRoutes)
}
}
// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
}
time.Sleep(5 * time.Second)
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 1)
// Node 1 has active route
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
// Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status()
clientStatus, err := client.Status()
assertNoErr(t, err)
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
t.Logf("subnet1 has following routes: %v", srs1PeerStatus.PrimaryRoutes.AsSlice())
assert.Len(t, srs1PeerStatus.PrimaryRoutes.AsSlice(), 1)
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
clientNm, err := client.Netmap()
assertNoErr(t, err)
wantClientFilter := []filter.Match{
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("100.64.0.2/32"),
Ports: filter.PortRange{0, 0xffff},
},
{
Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
}
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
}
subnetNm, err := subRouter1.Netmap()
assertNoErr(t, err)
wantSubnetFilter := []filter.Match{
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("100.64.0.1/32"),
Ports: filter.PortRange{0, 0xffff},
},
{
Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
{
IPProto: []ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
},
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
netip.MustParsePrefix("100.64.0.2/32"),
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
},
Dsts: []filter.NetPortRange{
{
Net: netip.MustParsePrefix("10.33.0.0/16"),
Ports: filter.PortRange{0, 0xffff},
},
},
Caps: []filter.CapMatch{},
},
}
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
}
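util.PrefixComparer used in the cmp.Diff calls above is headscale's go-cmp option for netip.Prefix; a minimal equivalent (a sketch; netip.Prefix supports ==) would be:
// Sketch: a go-cmp comparer for netip.Prefix, similar in spirit
// to util.PrefixComparer.
var prefixComparer = cmp.Comparer(func(a, b netip.Prefix) bool {
	return a == b
})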

View file

@ -56,8 +56,8 @@ var (
"1.44": true, // CapVer: 63
"1.42": true, // CapVer: 61
"1.40": true, // CapVer: 61
"1.38": true, // CapVer: 58
"1.36": true, // Oldest supported version, CapVer: 56
"1.38": true, // Oldest supported version, CapVer: 58
"1.36": false, // CapVer: 56
"1.34": false, // CapVer: 51
"1.32": false, // CapVer: 46
"1.30": false,

View file

@ -7,6 +7,8 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/netcheck"
"tailscale.com/types/netmap"
)
// nolint
@ -26,6 +28,8 @@ type TailscaleClient interface {
IPs() ([]netip.Addr, error)
FQDN() (string, error)
Status() (*ipnstate.Status, error)
Netmap() (*netmap.NetworkMap, error)
Netcheck() (*netcheck.Report, error)
WaitForNeedsLogin() error
WaitForRunning() error
WaitForPeers(expected int) error

View file

@ -17,6 +17,8 @@ import (
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/netcheck"
"tailscale.com/types/netmap"
)
const (
@ -519,6 +521,53 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
return &status, err
}
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
// Only works with Tailscale 1.56.1 and newer.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
command := []string{
"tailscale",
"debug",
"netmap",
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr)
return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
}
var nm netmap.NetworkMap
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
}
return &nm, err
}
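A typical call site for Netmap looks like the following sketch; the fields match the assertions added later in this commit:
// Sketch: using Netmap from a test.
nm, err := client.Netmap()
if err != nil {
	t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
if nm.SelfNode.Key().IsZero() {
	t.Errorf("%q has no valid node key in its netmap", client.Hostname())
}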
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
command := []string{
"tailscale",
"netcheck",
"--format=json",
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr)
return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err)
}
var nm netcheck.Report
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
}
return &nm, err
}
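And the corresponding Netcheck call site (a sketch; PreferredDERP == 0 means the client has no DERP home region, as asserted later in this commit):
// Sketch: using Netcheck from a test.
report, err := client.Netcheck()
if err != nil {
	t.Fatalf("netcheck for %q: %s", client.Hostname(), err)
}
if report.PreferredDERP == 0 {
	t.Errorf("%q has no preferred DERP region", client.Hostname())
}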
// FQDN returns the FQDN as a string of the Tailscale instance.
func (t *TailscaleInContainer) FQDN() (string, error) {
if t.fqdn != "" {
@ -623,12 +672,22 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
len(peers),
)
} else {
// Verify that each peer of a given node is Online,
// has a hostname and a DERP relay.
for _, peerKey := range peers {
peer := status.Peer[peerKey]
if !peer.Online {
return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
}
if peer.HostName == "" {
return fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)
}
if peer.Relay == "" {
return fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)
}
}
}

View file

@ -7,6 +7,8 @@ import (
"time"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"tailscale.com/util/cmpver"
)
const (
@ -83,7 +85,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
for _, addr := range addrs {
err := client.Ping(addr, opts...)
if err != nil {
t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {
success++
}
@ -120,6 +122,148 @@ func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string)
return success
}
// assertClientsState validates the status and netmap of a list of
// clients for the general case of all-to-all connectivity.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
t.Helper()
for _, client := range clients {
assertValidStatus(t, client)
assertValidNetmap(t, client)
assertValidNetcheck(t, client)
}
}
// assertValidNetmap asserts that the netmap of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
// This test can only be run on clients from 1.56.1. It will
// automatically pass for all clients below that version and is
// safe to call for all versions.
func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Helper()
if cmpver.Compare("1.56.1", client.Version()) <= 0 ||
!strings.Contains(client.Hostname(), "unstable") ||
!strings.Contains(client.Hostname(), "head") {
return
}
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.DERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.DERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// NetInfo is not always set, but these all-to-all tests expect it
assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(t, *peer.Online(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}
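The version gate at the top of assertValidNetmap is easy to misread; an explicit formulation of the documented intent (a sketch, assuming cmpver's three-way comparison and that dev builds carry "unstable" or "head" in their hostname) would be:
// Sketch: skip the netmap checks for clients older than 1.56.1,
// unless they are dev builds ("unstable"/"head").
tooOld := cmpver.Compare(client.Version(), "1.56.1") < 0
devBuild := strings.Contains(client.Hostname(), "unstable") ||
	strings.Contains(client.Hostname(), "head")
if tooOld && !devBuild {
	return
}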
// assertValidStatus asserts that the status of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper()
status, err := client.Status()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
// This does not seem to appear until version 1.56
if status.Self.AllowedIPs != nil {
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
}
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
// This does not seem to appear until version 1.56
if peer.AllowedIPs != nil {
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
}
// Addrs does not seem to appear in the status from peers.
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
// there might be some interesting stuff to test here in the future.
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
}
}
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
t.Helper()
report, err := client.Netcheck()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}
func isSelfClient(client TailscaleClient, addr string) bool {
if addr == client.Hostname() {
return true
@ -152,7 +296,7 @@ func isCI() bool {
}
func dockertestMaxWait() time.Duration {
wait := 60 * time.Second //nolint
wait := 120 * time.Second //nolint
if isCI() {
wait = 300 * time.Second //nolint