Rework map session
This commit restructures the map session in to a struct holding the state of what is needed during its lifetime. For streaming sessions, the event loop is structured a bit differently not hammering the clients with updates but rather batching them over a short, configurable time which should significantly improve cpu usage, and potentially flakyness. The use of Patch updates has been dialed back a little as it does not look like its a 100% ready for prime time. Nodes are now updated with full changes, except for a few things like online status. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
dd693c444c
commit
58c94d2bd3
35 changed files with 1803 additions and 1716 deletions
3
.github/workflows/test-integration.yaml
vendored
3
.github/workflows/test-integration.yaml
vendored
|
@ -43,7 +43,8 @@ jobs:
|
||||||
- TestTaildrop
|
- TestTaildrop
|
||||||
- TestResolveMagicDNS
|
- TestResolveMagicDNS
|
||||||
- TestExpireNode
|
- TestExpireNode
|
||||||
- TestNodeOnlineLastSeenStatus
|
- TestNodeOnlineStatus
|
||||||
|
- TestPingAllByIPManyUpDown
|
||||||
- TestEnablingRoutes
|
- TestEnablingRoutes
|
||||||
- TestHASubnetRouterFailover
|
- TestHASubnetRouterFailover
|
||||||
- TestEnableDisableAutoApprovedRoute
|
- TestEnableDisableAutoApprovedRoute
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -150,6 +150,7 @@ require (
|
||||||
github.com/opencontainers/image-spec v1.1.0-rc6 // indirect
|
github.com/opencontainers/image-spec v1.1.0-rc6 // indirect
|
||||||
github.com/opencontainers/runc v1.1.12 // indirect
|
github.com/opencontainers/runc v1.1.12 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.1.1 // indirect
|
github.com/pelletier/go-toml/v2 v2.1.1 // indirect
|
||||||
|
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
|
||||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
@ -161,6 +162,7 @@ require (
|
||||||
github.com/safchain/ethtool v0.3.0 // indirect
|
github.com/safchain/ethtool v0.3.0 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
|
github.com/sasha-s/go-deadlock v0.3.1 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
github.com/spf13/afero v1.11.0 // indirect
|
github.com/spf13/afero v1.11.0 // indirect
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -336,6 +336,8 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
|
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
|
||||||
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||||
|
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
|
||||||
|
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
|
||||||
github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA=
|
github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA=
|
||||||
github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ=
|
github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ=
|
||||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||||
|
@ -392,6 +394,8 @@ github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6g
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||||
|
github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0=
|
||||||
|
github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM=
|
||||||
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||||
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
|
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
|
||||||
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
|
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/derp"
|
"github.com/juanfont/headscale/hscontrol/derp"
|
||||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
@ -38,6 +39,7 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
zl "github.com/rs/zerolog"
|
zl "github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/sasha-s/go-deadlock"
|
||||||
"golang.org/x/crypto/acme"
|
"golang.org/x/crypto/acme"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -77,6 +79,11 @@ const (
|
||||||
registerCacheCleanup = time.Minute * 20
|
registerCacheCleanup = time.Minute * 20
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
deadlock.Opts.DeadlockTimeout = 15 * time.Second
|
||||||
|
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||||
|
}
|
||||||
|
|
||||||
// Headscale represents the base app of the service.
|
// Headscale represents the base app of the service.
|
||||||
type Headscale struct {
|
type Headscale struct {
|
||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
|
@ -89,6 +96,7 @@ type Headscale struct {
|
||||||
|
|
||||||
ACLPolicy *policy.ACLPolicy
|
ACLPolicy *policy.ACLPolicy
|
||||||
|
|
||||||
|
mapper *mapper.Mapper
|
||||||
nodeNotifier *notifier.Notifier
|
nodeNotifier *notifier.Notifier
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
|
@ -96,8 +104,10 @@ type Headscale struct {
|
||||||
|
|
||||||
registrationCache *cache.Cache
|
registrationCache *cache.Cache
|
||||||
|
|
||||||
shutdownChan chan struct{}
|
|
||||||
pollNetMapStreamWG sync.WaitGroup
|
pollNetMapStreamWG sync.WaitGroup
|
||||||
|
|
||||||
|
mapSessions map[types.NodeID]*mapSession
|
||||||
|
mapSessionMu deadlock.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -129,6 +139,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
pollNetMapStreamWG: sync.WaitGroup{},
|
pollNetMapStreamWG: sync.WaitGroup{},
|
||||||
nodeNotifier: notifier.NewNotifier(),
|
nodeNotifier: notifier.NewNotifier(),
|
||||||
|
mapSessions: make(map[types.NodeID]*mapSession),
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db, err = db.NewHeadscaleDatabase(
|
app.db, err = db.NewHeadscaleDatabase(
|
||||||
|
@ -199,16 +210,16 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
||||||
http.Redirect(w, req, target, http.StatusFound)
|
http.Redirect(w, req, target, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// expireEphemeralNodes deletes ephemeral node records that have not been
|
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
|
||||||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
|
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
|
||||||
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
|
|
||||||
var update types.StateUpdate
|
|
||||||
var changed bool
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
|
var removed []types.NodeID
|
||||||
|
var changed []types.NodeID
|
||||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
@ -216,9 +227,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if changed && update.Valid() {
|
if removed != nil {
|
||||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||||
h.nodeNotifier.NotifyAll(ctx, update)
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerRemoved,
|
||||||
|
Removed: removed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed != nil {
|
||||||
|
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||||
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: changed,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -243,8 +265,9 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
|
if changed {
|
||||||
if changed && update.Valid() {
|
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||||
h.nodeNotifier.NotifyAll(ctx, update)
|
h.nodeNotifier.NotifyAll(ctx, update)
|
||||||
}
|
}
|
||||||
|
@ -272,14 +295,11 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
||||||
h.DERPMap.Regions[region.RegionID] = ®ion
|
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdate{
|
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
||||||
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StateDERPUpdated,
|
Type: types.StateDERPUpdated,
|
||||||
DERPMap: h.DERPMap,
|
DERPMap: h.DERPMap,
|
||||||
}
|
})
|
||||||
if stateUpdate.Valid() {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
|
||||||
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -502,6 +522,7 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
// Fetch an initial DERP Map before we start serving
|
// Fetch an initial DERP Map before we start serving
|
||||||
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
||||||
|
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled {
|
if h.cfg.DERP.ServerEnabled {
|
||||||
// When embedded DERP is enabled we always need a STUN server
|
// When embedded DERP is enabled we always need a STUN server
|
||||||
|
@ -533,7 +554,7 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
// TODO(kradalby): These should have cancel channels and be cleaned
|
// TODO(kradalby): These should have cancel channels and be cleaned
|
||||||
// up on shutdown.
|
// up on shutdown.
|
||||||
go h.expireEphemeralNodes(updateInterval)
|
go h.deleteExpireEphemeralNodes(updateInterval)
|
||||||
go h.expireExpiredMachines(updateInterval)
|
go h.expireExpiredMachines(updateInterval)
|
||||||
|
|
||||||
if zl.GlobalLevel() == zl.TraceLevel {
|
if zl.GlobalLevel() == zl.TraceLevel {
|
||||||
|
@ -686,6 +707,9 @@ func (h *Headscale) Serve() error {
|
||||||
// no good way to handle streaming timeouts, therefore we need to
|
// no good way to handle streaming timeouts, therefore we need to
|
||||||
// keep this at unlimited and be careful to clean up connections
|
// keep this at unlimited and be careful to clean up connections
|
||||||
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
|
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
|
||||||
|
// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
|
||||||
|
// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
|
||||||
|
// replace this so only the longpoller has no timeout.
|
||||||
WriteTimeout: 0,
|
WriteTimeout: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -742,7 +766,6 @@ func (h *Headscale) Serve() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle common process-killing signals so we can gracefully shut down:
|
// Handle common process-killing signals so we can gracefully shut down:
|
||||||
h.shutdownChan = make(chan struct{})
|
|
||||||
sigc := make(chan os.Signal, 1)
|
sigc := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigc,
|
signal.Notify(sigc,
|
||||||
syscall.SIGHUP,
|
syscall.SIGHUP,
|
||||||
|
@ -785,8 +808,6 @@ func (h *Headscale) Serve() error {
|
||||||
Str("signal", sig.String()).
|
Str("signal", sig.String()).
|
||||||
Msg("Received signal to stop, shutting down gracefully")
|
Msg("Received signal to stop, shutting down gracefully")
|
||||||
|
|
||||||
close(h.shutdownChan)
|
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Wait()
|
h.pollNetMapStreamWG.Wait()
|
||||||
|
|
||||||
// Gracefully shut down servers
|
// Gracefully shut down servers
|
||||||
|
|
|
@ -352,13 +352,8 @@ func (h *Headscale) handleAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mkey := node.MachineKey
|
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
|
||||||
update := types.StateUpdateExpire(node.ID, registerRequest.Expiry)
|
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
|
||||||
|
|
||||||
if update.Valid() {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
|
|
||||||
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
@ -538,11 +533,8 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdateExpire(node.ID, now)
|
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||||
if stateUpdate.Valid() {
|
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
|
||||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
|
||||||
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
|
@ -572,7 +564,7 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.IsEphemeral() {
|
if node.IsEphemeral() {
|
||||||
err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
|
changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -580,13 +572,16 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
Msg("Cannot delete ephemeral node from the database")
|
Msg("Cannot delete ephemeral node from the database")
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdate{
|
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||||
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StatePeerRemoved,
|
Type: types.StatePeerRemoved,
|
||||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
Removed: []types.NodeID{node.ID},
|
||||||
}
|
})
|
||||||
if stateUpdate.Valid() {
|
if changedNodes != nil {
|
||||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: changedNodes,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -34,27 +34,22 @@ var (
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
|
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return ListPeers(rx, node)
|
return ListPeers(rx, nodeID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
|
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
|
||||||
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
|
func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Finding direct peers")
|
|
||||||
|
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
if err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Preload("Routes").
|
Preload("Routes").
|
||||||
Where("node_key <> ?",
|
Where("id <> ?",
|
||||||
node.NodeKey.String()).Find(&nodes).Error; err != nil {
|
nodeID).Find(&nodes).Error; err != nil {
|
||||||
return types.Nodes{}, err
|
return types.Nodes{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,14 +114,14 @@ func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
|
||||||
return nil, ErrNodeNotFound
|
return nil, ErrNodeNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
|
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||||
return GetNodeByID(rx, id)
|
return GetNodeByID(rx, id)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNodeByID finds a Node by ID and returns the Node struct.
|
// GetNodeByID finds a Node by ID and returns the Node struct.
|
||||||
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
|
func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
|
||||||
mach := types.Node{}
|
mach := types.Node{}
|
||||||
if result := tx.
|
if result := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
|
@ -197,7 +192,7 @@ func GetNodeByAnyKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) SetTags(
|
func (hsdb *HSDatabase) SetTags(
|
||||||
nodeID uint64,
|
nodeID types.NodeID,
|
||||||
tags []string,
|
tags []string,
|
||||||
) error {
|
) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
@ -208,7 +203,7 @@ func (hsdb *HSDatabase) SetTags(
|
||||||
// SetTags takes a Node struct pointer and update the forced tags.
|
// SetTags takes a Node struct pointer and update the forced tags.
|
||||||
func SetTags(
|
func SetTags(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
nodeID uint64,
|
nodeID types.NodeID,
|
||||||
tags []string,
|
tags []string,
|
||||||
) error {
|
) error {
|
||||||
if len(tags) == 0 {
|
if len(tags) == 0 {
|
||||||
|
@ -256,7 +251,7 @@ func RenameNode(tx *gorm.DB,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
|
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return NodeSetExpiry(tx, nodeID, expiry)
|
return NodeSetExpiry(tx, nodeID, expiry)
|
||||||
})
|
})
|
||||||
|
@ -264,13 +259,13 @@ func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
|
||||||
|
|
||||||
// NodeSetExpiry takes a Node struct and a new expiry time.
|
// NodeSetExpiry takes a Node struct and a new expiry time.
|
||||||
func NodeSetExpiry(tx *gorm.DB,
|
func NodeSetExpiry(tx *gorm.DB,
|
||||||
nodeID uint64, expiry time.Time,
|
nodeID types.NodeID, expiry time.Time,
|
||||||
) error {
|
) error {
|
||||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error {
|
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return DeleteNode(tx, node, isConnected)
|
return DeleteNode(tx, node, isConnected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -279,24 +274,24 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.Machine
|
||||||
// Caller is responsible for notifying all of change.
|
// Caller is responsible for notifying all of change.
|
||||||
func DeleteNode(tx *gorm.DB,
|
func DeleteNode(tx *gorm.DB,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
) error {
|
) ([]types.NodeID, error) {
|
||||||
err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{})
|
changed, err := deleteNodeRoutes(tx, node, isConnected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unscoped causes the node to be fully removed from the database.
|
// Unscoped causes the node to be fully removed from the database.
|
||||||
if err := tx.Unscoped().Delete(&node).Error; err != nil {
|
if err := tx.Unscoped().Delete(&node).Error; err != nil {
|
||||||
return err
|
return changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastSeen sets a node's last seen field indicating that we
|
// SetLastSeen sets a node's last seen field indicating that we
|
||||||
// have recently communicating with this node.
|
// have recently communicating with this node.
|
||||||
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
|
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
||||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,7 +601,7 @@ func enableRoutes(tx *gorm.DB,
|
||||||
|
|
||||||
return &types.StateUpdate{
|
return &types.StateUpdate{
|
||||||
Type: types.StatePeerChanged,
|
Type: types.StatePeerChanged,
|
||||||
ChangeNodes: types.Nodes{node},
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
Message: "created in db.enableRoutes",
|
Message: "created in db.enableRoutes",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -681,17 +676,18 @@ func GenerateGivenName(
|
||||||
return givenName, nil
|
return givenName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExpireEphemeralNodes(tx *gorm.DB,
|
func DeleteExpiredEphemeralNodes(tx *gorm.DB,
|
||||||
inactivityThreshhold time.Duration,
|
inactivityThreshhold time.Duration,
|
||||||
) (types.StateUpdate, bool) {
|
) ([]types.NodeID, []types.NodeID) {
|
||||||
users, err := ListUsers(tx)
|
users, err := ListUsers(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
log.Error().Err(err).Msg("Error listing users")
|
||||||
|
|
||||||
return types.StateUpdate{}, false
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
expired := make([]tailcfg.NodeID, 0)
|
var expired []types.NodeID
|
||||||
|
var changedNodes []types.NodeID
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
nodes, err := ListNodesByUser(tx, user.Name)
|
nodes, err := ListNodesByUser(tx, user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -700,40 +696,36 @@ func ExpireEphemeralNodes(tx *gorm.DB,
|
||||||
Str("user", user.Name).
|
Str("user", user.Name).
|
||||||
Msg("Error listing nodes in user")
|
Msg("Error listing nodes in user")
|
||||||
|
|
||||||
return types.StateUpdate{}, false
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, node := range nodes {
|
for idx, node := range nodes {
|
||||||
if node.IsEphemeral() && node.LastSeen != nil &&
|
if node.IsEphemeral() && node.LastSeen != nil &&
|
||||||
time.Now().
|
time.Now().
|
||||||
After(node.LastSeen.Add(inactivityThreshhold)) {
|
After(node.LastSeen.Add(inactivityThreshhold)) {
|
||||||
expired = append(expired, tailcfg.NodeID(node.ID))
|
expired = append(expired, node.ID)
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("Ephemeral client removed from database")
|
Msg("Ephemeral client removed from database")
|
||||||
|
|
||||||
// empty isConnected map as ephemeral nodes are not routes
|
// empty isConnected map as ephemeral nodes are not routes
|
||||||
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{})
|
changed, err := DeleteNode(tx, nodes[idx], nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("🤮 Cannot delete ephemeral node from the database")
|
Msg("🤮 Cannot delete ephemeral node from the database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
changedNodes = append(changedNodes, changed...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): needs to be moved out of transaction
|
// TODO(kradalby): needs to be moved out of transaction
|
||||||
}
|
}
|
||||||
if len(expired) > 0 {
|
|
||||||
return types.StateUpdate{
|
|
||||||
Type: types.StatePeerRemoved,
|
|
||||||
Removed: expired,
|
|
||||||
}, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return types.StateUpdate{}, false
|
return expired, changedNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExpireExpiredNodes(tx *gorm.DB,
|
func ExpireExpiredNodes(tx *gorm.DB,
|
||||||
|
@ -754,35 +746,12 @@ func ExpireExpiredNodes(tx *gorm.DB,
|
||||||
|
|
||||||
return time.Unix(0, 0), types.StateUpdate{}, false
|
return time.Unix(0, 0), types.StateUpdate{}, false
|
||||||
}
|
}
|
||||||
for index, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.IsExpired() &&
|
if node.IsExpired() && node.Expiry.After(lastCheck) {
|
||||||
// TODO(kradalby): Replace this, it is very spammy
|
|
||||||
// It will notify about all nodes that has been expired.
|
|
||||||
// It should only notify about expired nodes since _last check_.
|
|
||||||
node.Expiry.After(lastCheck) {
|
|
||||||
expired = append(expired, &tailcfg.PeerChange{
|
expired = append(expired, &tailcfg.PeerChange{
|
||||||
NodeID: tailcfg.NodeID(node.ID),
|
NodeID: tailcfg.NodeID(node.ID),
|
||||||
KeyExpiry: node.Expiry,
|
KeyExpiry: node.Expiry,
|
||||||
})
|
})
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
// Do not use setNodeExpiry as that has a notifier hook, which
|
|
||||||
// can cause a deadlock, we are updating all changed nodes later
|
|
||||||
// and there is no point in notifiying twice.
|
|
||||||
if err := tx.Model(&nodes[index]).Updates(types.Node{
|
|
||||||
Expiry: &now,
|
|
||||||
}).Error; err != nil {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Str("name", node.GivenName).
|
|
||||||
Msg("🤮 Cannot expire node")
|
|
||||||
} else {
|
|
||||||
log.Info().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Str("name", node.GivenName).
|
|
||||||
Msg("Node successfully expired")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
}
|
}
|
||||||
db.DB.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
err = db.DeleteNode(&node, map[key.MachinePublic]bool{})
|
_, err = db.DeleteNode(&node, types.NodeConnectedMap{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(user.Name, "testnode3")
|
_, err = db.getNode(user.Name, "testnode3")
|
||||||
|
@ -142,7 +142,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: uint64(index),
|
ID: types.NodeID(index),
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey.Public(),
|
NodeKey: nodeKey.Public(),
|
||||||
Hostname: "testnode" + strconv.Itoa(index),
|
Hostname: "testnode" + strconv.Itoa(index),
|
||||||
|
@ -156,7 +156,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
node0ByID, err := db.GetNodeByID(0)
|
node0ByID, err := db.GetNodeByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfNode0, err := db.ListPeers(node0ByID)
|
peersOfNode0, err := db.ListPeers(node0ByID.ID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(peersOfNode0), check.Equals, 9)
|
c.Assert(len(peersOfNode0), check.Equals, 9)
|
||||||
|
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: uint64(index),
|
ID: types.NodeID(index),
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey.Public(),
|
NodeKey: nodeKey.Public(),
|
||||||
IPAddresses: types.NodeAddresses{
|
IPAddresses: types.NodeAddresses{
|
||||||
|
@ -232,16 +232,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
|
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
adminPeers, err := db.ListPeers(adminNode)
|
adminPeers, err := db.ListPeers(adminNode.ID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testPeers, err := db.ListPeers(testNode)
|
testPeers, err := db.ListPeers(testNode.ID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers)
|
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers)
|
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||||
|
@ -586,7 +586,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// TODO(kradalby): Check state update
|
// TODO(kradalby): Check state update
|
||||||
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID)
|
err = db.EnableAutoApprovedRoutes(pol, node0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
|
enabledRoutes, err := db.GetEnabledRoutes(node0ByID)
|
||||||
|
|
|
@ -92,10 +92,6 @@ func CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -148,7 +148,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db.DB.Transaction(func(tx *gorm.DB) error {
|
db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
ExpireEphemeralNodes(tx, time.Second*20)
|
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db.DB.Transaction(func(tx *gorm.DB) error {
|
db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
ExpireEphemeralNodes(tx, time.Second*20)
|
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||||
|
@ -124,8 +123,8 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
||||||
|
|
||||||
func DisableRoute(tx *gorm.DB,
|
func DisableRoute(tx *gorm.DB,
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
) (*types.StateUpdate, error) {
|
) ([]types.NodeID, error) {
|
||||||
route, err := GetRoute(tx, id)
|
route, err := GetRoute(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -137,16 +136,15 @@ func DisableRoute(tx *gorm.DB,
|
||||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
var update *types.StateUpdate
|
var update []types.NodeID
|
||||||
if !route.IsExitRoute() {
|
if !route.IsExitRoute() {
|
||||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
route.Enabled = false
|
||||||
|
err = tx.Save(route).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
route.Enabled = false
|
update, err = failoverRouteTx(tx, isConnected, route)
|
||||||
route.IsPrimary = false
|
|
||||||
err = tx.Save(route).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -160,6 +158,7 @@ func DisableRoute(tx *gorm.DB,
|
||||||
if routes[i].IsExitRoute() {
|
if routes[i].IsExitRoute() {
|
||||||
routes[i].Enabled = false
|
routes[i].Enabled = false
|
||||||
routes[i].IsPrimary = false
|
routes[i].IsPrimary = false
|
||||||
|
|
||||||
err = tx.Save(&routes[i]).Error
|
err = tx.Save(&routes[i]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -168,26 +167,11 @@ func DisableRoute(tx *gorm.DB,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if routes == nil {
|
|
||||||
routes, err = GetNodeRoutes(tx, &node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
node.Routes = routes
|
|
||||||
|
|
||||||
// If update is empty, it means that one was not created
|
// If update is empty, it means that one was not created
|
||||||
// by failover (as a failover was not necessary), create
|
// by failover (as a failover was not necessary), create
|
||||||
// one and return to the caller.
|
// one and return to the caller.
|
||||||
if update == nil {
|
if update == nil {
|
||||||
update = &types.StateUpdate{
|
update = []types.NodeID{node.ID}
|
||||||
Type: types.StatePeerChanged,
|
|
||||||
ChangeNodes: types.Nodes{
|
|
||||||
&node,
|
|
||||||
},
|
|
||||||
Message: "called from db.DisableRoute",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return update, nil
|
return update, nil
|
||||||
|
@ -195,9 +179,9 @@ func DisableRoute(tx *gorm.DB,
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteRoute(
|
func (hsdb *HSDatabase) DeleteRoute(
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
) (*types.StateUpdate, error) {
|
) ([]types.NodeID, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return DeleteRoute(tx, id, isConnected)
|
return DeleteRoute(tx, id, isConnected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -205,8 +189,8 @@ func (hsdb *HSDatabase) DeleteRoute(
|
||||||
func DeleteRoute(
|
func DeleteRoute(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
) (*types.StateUpdate, error) {
|
) ([]types.NodeID, error) {
|
||||||
route, err := GetRoute(tx, id)
|
route, err := GetRoute(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -218,9 +202,9 @@ func DeleteRoute(
|
||||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
var update *types.StateUpdate
|
var update []types.NodeID
|
||||||
if !route.IsExitRoute() {
|
if !route.IsExitRoute() {
|
||||||
update, err = failoverRouteReturnUpdate(tx, isConnected, route)
|
update, err = failoverRouteTx(tx, isConnected, route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -229,7 +213,7 @@ func DeleteRoute(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
routes, err := GetNodeRoutes(tx, &node)
|
routes, err = GetNodeRoutes(tx, &node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -259,35 +243,37 @@ func DeleteRoute(
|
||||||
node.Routes = routes
|
node.Routes = routes
|
||||||
|
|
||||||
if update == nil {
|
if update == nil {
|
||||||
update = &types.StateUpdate{
|
update = []types.NodeID{node.ID}
|
||||||
Type: types.StatePeerChanged,
|
|
||||||
ChangeNodes: types.Nodes{
|
|
||||||
&node,
|
|
||||||
},
|
|
||||||
Message: "called from db.DeleteRoute",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error {
|
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
|
||||||
routes, err := GetNodeRoutes(tx, node)
|
routes, err := GetNodeRoutes(tx, node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var changed []types.NodeID
|
||||||
for i := range routes {
|
for i := range routes {
|
||||||
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
|
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): This is a bit too aggressive, we could probably
|
// TODO(kradalby): This is a bit too aggressive, we could probably
|
||||||
// figure out which routes needs to be failed over rather than all.
|
// figure out which routes needs to be failed over rather than all.
|
||||||
failoverRouteReturnUpdate(tx, isConnected, &routes[i])
|
chn, err := failoverRouteTx(tx, isConnected, &routes[i])
|
||||||
|
if err != nil {
|
||||||
|
return changed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if chn != nil {
|
||||||
|
changed = append(changed, chn...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isUniquePrefix returns if there is another node providing the same route already.
|
// isUniquePrefix returns if there is another node providing the same route already.
|
||||||
|
@ -400,7 +386,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||||
for prefix, exists := range advertisedRoutes {
|
for prefix, exists := range advertisedRoutes {
|
||||||
if !exists {
|
if !exists {
|
||||||
route := types.Route{
|
route := types.Route{
|
||||||
NodeID: node.ID,
|
NodeID: node.ID.Uint64(),
|
||||||
Prefix: types.IPPrefix(prefix),
|
Prefix: types.IPPrefix(prefix),
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
|
@ -415,19 +401,23 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||||
return sendUpdate, nil
|
return sendUpdate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route
|
// FailoverRouteIfAvailable takes a node and checks if the node's route
|
||||||
// currently have a functioning host that exposes the network.
|
// currently have a functioning host that exposes the network.
|
||||||
func EnsureFailoverRouteIsAvailable(
|
// If it does not, it is failed over to another suitable route if there
|
||||||
|
// is one.
|
||||||
|
func FailoverRouteIfAvailable(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) (*types.StateUpdate, error) {
|
) (*types.StateUpdate, error) {
|
||||||
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Msgf("ROUTE DEBUG ENTERED FAILOVER")
|
||||||
nodeRoutes, err := GetNodeRoutes(tx, node)
|
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("nodeRoutes", nodeRoutes).Msgf("ROUTE DEBUG NO ROUTES")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var changedNodes types.Nodes
|
var changedNodes []types.NodeID
|
||||||
for _, nodeRoute := range nodeRoutes {
|
for _, nodeRoute := range nodeRoutes {
|
||||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -438,71 +428,39 @@ func EnsureFailoverRouteIsAvailable(
|
||||||
if route.IsPrimary {
|
if route.IsPrimary {
|
||||||
// if we have a primary route, and the node is connected
|
// if we have a primary route, and the node is connected
|
||||||
// nothing needs to be done.
|
// nothing needs to be done.
|
||||||
if isConnected[route.Node.MachineKey] {
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG CHECKING IF ONLINE")
|
||||||
continue
|
if isConnected[route.Node.ID] {
|
||||||
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG IS ONLINE")
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG NOT ONLINE, FAILING OVER")
|
||||||
// if not, we need to failover the route
|
// if not, we need to failover the route
|
||||||
update, err := failoverRouteReturnUpdate(tx, isConnected, &route)
|
changedIDs, err := failoverRouteTx(tx, isConnected, &route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if update != nil {
|
if changedIDs != nil {
|
||||||
changedNodes = append(changedNodes, update.ChangeNodes...)
|
changedNodes = append(changedNodes, changedIDs...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("changedNodes", changedNodes).Msgf("ROUTE DEBUG")
|
||||||
if len(changedNodes) != 0 {
|
if len(changedNodes) != 0 {
|
||||||
return &types.StateUpdate{
|
return &types.StateUpdate{
|
||||||
Type: types.StatePeerChanged,
|
Type: types.StatePeerChanged,
|
||||||
ChangeNodes: changedNodes,
|
ChangeNodes: changedNodes,
|
||||||
Message: "called from db.EnsureFailoverRouteIsAvailable",
|
Message: "called from db.FailoverRouteIfAvailable",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func failoverRouteReturnUpdate(
|
// failoverRouteTx takes a route that is no longer available,
|
||||||
tx *gorm.DB,
|
|
||||||
isConnected map[key.MachinePublic]bool,
|
|
||||||
r *types.Route,
|
|
||||||
) (*types.StateUpdate, error) {
|
|
||||||
changedKeys, err := failoverRoute(tx, isConnected, r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Interface("isConnected", isConnected).
|
|
||||||
Interface("changedKeys", changedKeys).
|
|
||||||
Msg("building route failover")
|
|
||||||
|
|
||||||
if len(changedKeys) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var nodes types.Nodes
|
|
||||||
for _, key := range changedKeys {
|
|
||||||
node, err := GetNodeByMachineKey(tx, key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes = append(nodes, node)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &types.StateUpdate{
|
|
||||||
Type: types.StatePeerChanged,
|
|
||||||
ChangeNodes: nodes,
|
|
||||||
Message: "called from db.failoverRouteReturnUpdate",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// failoverRoute takes a route that is no longer available,
|
|
||||||
// this can be either from:
|
// this can be either from:
|
||||||
// - being disabled
|
// - being disabled
|
||||||
// - being deleted
|
// - being deleted
|
||||||
|
@ -510,11 +468,11 @@ func failoverRouteReturnUpdate(
|
||||||
//
|
//
|
||||||
// and tries to find a new route to take over its place.
|
// and tries to find a new route to take over its place.
|
||||||
// If the given route was not primary, it returns early.
|
// If the given route was not primary, it returns early.
|
||||||
func failoverRoute(
|
func failoverRouteTx(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
isConnected map[key.MachinePublic]bool,
|
isConnected types.NodeConnectedMap,
|
||||||
r *types.Route,
|
r *types.Route,
|
||||||
) ([]key.MachinePublic, error) {
|
) ([]types.NodeID, error) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -535,11 +493,64 @@ func failoverRoute(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fo := failoverRoute(isConnected, r, routes)
|
||||||
|
if fo == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Save(fo.old).Error
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("disabling old primary route")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Save(fo.new).Error
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("saving new primary route")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Str("hostname", fo.new.Node.Hostname).
|
||||||
|
Msgf("set primary to new route, was: id(%d), host(%s), now: id(%d), host(%s)", fo.old.ID, fo.old.Node.Hostname, fo.new.ID, fo.new.Node.Hostname)
|
||||||
|
|
||||||
|
// Return a list of the machinekeys of the changed nodes.
|
||||||
|
return []types.NodeID{fo.old.Node.ID, fo.new.Node.ID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type failover struct {
|
||||||
|
old *types.Route
|
||||||
|
new *types.Route
|
||||||
|
}
|
||||||
|
|
||||||
|
func failoverRoute(
|
||||||
|
isConnected types.NodeConnectedMap,
|
||||||
|
routeToReplace *types.Route,
|
||||||
|
altRoutes types.Routes,
|
||||||
|
|
||||||
|
) *failover {
|
||||||
|
if routeToReplace == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This route is not a primary route, and it is not
|
||||||
|
// being served to nodes.
|
||||||
|
if !routeToReplace.IsPrimary {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do not have to failover exit nodes
|
||||||
|
if routeToReplace.IsExitRoute() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var newPrimary *types.Route
|
var newPrimary *types.Route
|
||||||
|
|
||||||
// Find a new suitable route
|
// Find a new suitable route
|
||||||
for idx, route := range routes {
|
for idx, route := range altRoutes {
|
||||||
if r.ID == route.ID {
|
if routeToReplace.ID == route.ID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -547,8 +558,8 @@ func failoverRoute(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if isConnected[route.Node.MachineKey] {
|
if isConnected != nil && isConnected[route.Node.ID] {
|
||||||
newPrimary = &routes[idx]
|
newPrimary = &altRoutes[idx]
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -559,48 +570,23 @@ func failoverRoute(
|
||||||
// the one currently marked as primary is the
|
// the one currently marked as primary is the
|
||||||
// best we got.
|
// best we got.
|
||||||
if newPrimary == nil {
|
if newPrimary == nil {
|
||||||
return nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
routeToReplace.IsPrimary = false
|
||||||
Str("hostname", newPrimary.Node.Hostname).
|
|
||||||
Msg("found new primary, updating db")
|
|
||||||
|
|
||||||
// Remove primary from the old route
|
|
||||||
r.IsPrimary = false
|
|
||||||
err = tx.Save(&r).Error
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error disabling new primary route")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("hostname", newPrimary.Node.Hostname).
|
|
||||||
Msg("removed primary from old route")
|
|
||||||
|
|
||||||
// Set primary for the new primary
|
|
||||||
newPrimary.IsPrimary = true
|
newPrimary.IsPrimary = true
|
||||||
err = tx.Save(&newPrimary).Error
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error enabling new primary route")
|
|
||||||
|
|
||||||
return nil, err
|
return &failover{
|
||||||
|
old: routeToReplace,
|
||||||
|
new: newPrimary,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("hostname", newPrimary.Node.Hostname).
|
|
||||||
Msg("set primary to new route")
|
|
||||||
|
|
||||||
// Return a list of the machinekeys of the changed nodes.
|
|
||||||
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||||
aclPolicy *policy.ACLPolicy,
|
aclPolicy *policy.ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) (*types.StateUpdate, error) {
|
) error {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
|
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -610,9 +596,9 @@ func EnableAutoApprovedRoutes(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
aclPolicy *policy.ACLPolicy,
|
aclPolicy *policy.ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) (*types.StateUpdate, error) {
|
) error {
|
||||||
if len(node.IPAddresses) == 0 {
|
if len(node.IPAddresses) == 0 {
|
||||||
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := GetNodeAdvertisedRoutes(tx, node)
|
routes, err := GetNodeAdvertisedRoutes(tx, node)
|
||||||
|
@ -623,7 +609,7 @@ func EnableAutoApprovedRoutes(
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("Could not get advertised routes for node")
|
Msg("Could not get advertised routes for node")
|
||||||
|
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
|
log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
|
||||||
|
@ -641,10 +627,10 @@ func EnableAutoApprovedRoutes(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).
|
log.Err(err).
|
||||||
Str("advertisedRoute", advertisedRoute.String()).
|
Str("advertisedRoute", advertisedRoute.String()).
|
||||||
Uint64("nodeId", node.ID).
|
Uint64("nodeId", node.ID.Uint64()).
|
||||||
Msg("Failed to resolve autoApprovers for advertised route")
|
Msg("Failed to resolve autoApprovers for advertised route")
|
||||||
|
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
|
@ -665,7 +651,7 @@ func EnableAutoApprovedRoutes(
|
||||||
Str("alias", approvedAlias).
|
Str("alias", approvedAlias).
|
||||||
Msg("Failed to expand alias when processing autoApprovers policy")
|
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||||
|
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||||
|
@ -676,25 +662,17 @@ func EnableAutoApprovedRoutes(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
update := &types.StateUpdate{
|
|
||||||
Type: types.StatePeerChanged,
|
|
||||||
ChangeNodes: types.Nodes{},
|
|
||||||
Message: "created in db.EnableAutoApprovedRoutes",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, approvedRoute := range approvedRoutes {
|
for _, approvedRoute := range approvedRoutes {
|
||||||
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID))
|
_, err := EnableRoute(tx, uint64(approvedRoute.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).
|
log.Err(err).
|
||||||
Str("approvedRoute", approvedRoute.String()).
|
Str("approvedRoute", approvedRoute.String()).
|
||||||
Uint64("nodeId", node.ID).
|
Uint64("nodeId", node.ID.Uint64()).
|
||||||
Msg("Failed to enable approved route")
|
Msg("Failed to enable approved route")
|
||||||
|
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return update, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
|
@ -262,7 +261,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// TODO(kradalby): check stateupdate
|
// TODO(kradalby): check stateupdate
|
||||||
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{})
|
_, err = db.DeleteRoute(uint64(routes[0].ID), nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
|
||||||
|
@ -272,20 +271,13 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
|
|
||||||
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||||
|
|
||||||
func TestFailoverRoute(t *testing.T) {
|
func TestFailoverRouteTx(t *testing.T) {
|
||||||
machineKeys := []key.MachinePublic{
|
|
||||||
key.NewMachine().Public(),
|
|
||||||
key.NewMachine().Public(),
|
|
||||||
key.NewMachine().Public(),
|
|
||||||
key.NewMachine().Public(),
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
failingRoute types.Route
|
failingRoute types.Route
|
||||||
routes types.Routes
|
routes types.Routes
|
||||||
isConnected map[key.MachinePublic]bool
|
isConnected types.NodeConnectedMap
|
||||||
want []key.MachinePublic
|
want []types.NodeID
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -301,10 +293,8 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{},
|
||||||
MachineKey: machineKeys[0],
|
|
||||||
},
|
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
},
|
},
|
||||||
routes: types.Routes{},
|
routes: types.Routes{},
|
||||||
|
@ -317,10 +307,8 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
},
|
},
|
||||||
Prefix: ipp("0.0.0.0/0"),
|
Prefix: ipp("0.0.0.0/0"),
|
||||||
Node: types.Node{
|
Node: types.Node{},
|
||||||
MachineKey: machineKeys[0],
|
|
||||||
},
|
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
},
|
},
|
||||||
routes: types.Routes{},
|
routes: types.Routes{},
|
||||||
|
@ -335,7 +323,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
},
|
},
|
||||||
|
@ -346,7 +334,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
},
|
},
|
||||||
|
@ -362,7 +350,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -374,7 +362,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -385,19 +373,19 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: map[key.MachinePublic]bool{
|
isConnected: types.NodeConnectedMap{
|
||||||
machineKeys[0]: false,
|
1: false,
|
||||||
machineKeys[1]: true,
|
2: true,
|
||||||
},
|
},
|
||||||
want: []key.MachinePublic{
|
want: []types.NodeID{
|
||||||
machineKeys[0],
|
1,
|
||||||
machineKeys[1],
|
2,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
@ -409,7 +397,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -421,7 +409,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -432,7 +420,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -449,7 +437,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -461,7 +449,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -472,7 +460,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -483,20 +471,19 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[2],
|
ID: 3,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: map[key.MachinePublic]bool{
|
isConnected: types.NodeConnectedMap{
|
||||||
machineKeys[0]: true,
|
1: true,
|
||||||
machineKeys[1]: true,
|
2: true,
|
||||||
machineKeys[2]: true,
|
3: true,
|
||||||
},
|
},
|
||||||
want: []key.MachinePublic{
|
want: []types.NodeID{
|
||||||
machineKeys[1],
|
2, 1,
|
||||||
machineKeys[0],
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
@ -508,7 +495,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -520,7 +507,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -532,15 +519,15 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[3],
|
ID: 4,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: map[key.MachinePublic]bool{
|
isConnected: types.NodeConnectedMap{
|
||||||
machineKeys[0]: true,
|
1: true,
|
||||||
machineKeys[3]: false,
|
4: false,
|
||||||
},
|
},
|
||||||
want: nil,
|
want: nil,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -553,7 +540,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -565,7 +552,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -577,7 +564,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[3],
|
ID: 4,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -588,20 +575,20 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: map[key.MachinePublic]bool{
|
isConnected: types.NodeConnectedMap{
|
||||||
machineKeys[0]: false,
|
1: false,
|
||||||
machineKeys[1]: true,
|
2: true,
|
||||||
machineKeys[3]: false,
|
4: false,
|
||||||
},
|
},
|
||||||
want: []key.MachinePublic{
|
want: []types.NodeID{
|
||||||
machineKeys[0],
|
1,
|
||||||
machineKeys[1],
|
2,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
@ -613,7 +600,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -625,7 +612,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[0],
|
ID: 1,
|
||||||
},
|
},
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -637,7 +624,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
},
|
},
|
||||||
Prefix: ipp("10.0.0.0/24"),
|
Prefix: ipp("10.0.0.0/24"),
|
||||||
Node: types.Node{
|
Node: types.Node{
|
||||||
MachineKey: machineKeys[1],
|
ID: 2,
|
||||||
},
|
},
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
|
@ -670,8 +657,8 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) {
|
got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return failoverRoute(tx, tt.isConnected, &tt.failingRoute)
|
return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
@ -687,230 +674,177 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// func TestDisableRouteFailover(t *testing.T) {
|
func TestFailoverRoute(t *testing.T) {
|
||||||
// machineKeys := []key.MachinePublic{
|
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
||||||
// key.NewMachine().Public(),
|
return types.Route{
|
||||||
// key.NewMachine().Public(),
|
Model: gorm.Model{
|
||||||
// key.NewMachine().Public(),
|
ID: id,
|
||||||
// key.NewMachine().Public(),
|
},
|
||||||
// }
|
Node: types.Node{
|
||||||
|
ID: nid,
|
||||||
|
},
|
||||||
|
Prefix: prefix,
|
||||||
|
Enabled: enabled,
|
||||||
|
IsPrimary: primary,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
||||||
|
ro := r(id, nid, prefix, enabled, primary)
|
||||||
|
return &ro
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
failingRoute types.Route
|
||||||
|
routes types.Routes
|
||||||
|
isConnected types.NodeConnectedMap
|
||||||
|
want *failover
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-route",
|
||||||
|
failingRoute: types.Route{},
|
||||||
|
routes: types.Routes{},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-prime",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, false),
|
||||||
|
|
||||||
// tests := []struct {
|
routes: types.Routes{},
|
||||||
// name string
|
want: nil,
|
||||||
// nodes types.Nodes
|
},
|
||||||
|
{
|
||||||
|
name: "exit-node",
|
||||||
|
failingRoute: r(1, 1, ipp("0.0.0.0/0"), false, true),
|
||||||
|
routes: types.Routes{},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-failover-single-route",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), false, true),
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
|
},
|
||||||
|
isConnected: types.NodeConnectedMap{
|
||||||
|
1: false,
|
||||||
|
2: true,
|
||||||
|
},
|
||||||
|
want: &failover{
|
||||||
|
old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||||
|
new: rp(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-none-primary",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary-multi-route",
|
||||||
|
failingRoute: r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||||
|
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
|
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||||
|
},
|
||||||
|
isConnected: types.NodeConnectedMap{
|
||||||
|
1: true,
|
||||||
|
2: true,
|
||||||
|
3: true,
|
||||||
|
},
|
||||||
|
want: &failover{
|
||||||
|
old: rp(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
|
new: rp(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary-no-online",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
||||||
|
},
|
||||||
|
isConnected: types.NodeConnectedMap{
|
||||||
|
1: true,
|
||||||
|
4: false,
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary-one-not-online",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
||||||
|
r(3, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
|
},
|
||||||
|
isConnected: types.NodeConnectedMap{
|
||||||
|
1: false,
|
||||||
|
2: true,
|
||||||
|
4: false,
|
||||||
|
},
|
||||||
|
want: &failover{
|
||||||
|
old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||||
|
new: rp(3, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failover-primary-none-enabled",
|
||||||
|
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
|
routes: types.Routes{
|
||||||
|
r(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||||
|
r(2, 2, ipp("10.0.0.0/24"), false, true),
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// routeID uint64
|
cmps := append(
|
||||||
// isConnected map[key.MachinePublic]bool
|
util.Comparers,
|
||||||
|
cmp.Comparer(func(x, y types.IPPrefix) bool {
|
||||||
|
return netip.Prefix(x) == netip.Prefix(y)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
// wantMachineKey key.MachinePublic
|
for _, tt := range tests {
|
||||||
// wantErr string
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// }{
|
gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes)
|
||||||
// {
|
|
||||||
// name: "single-route",
|
|
||||||
// nodes: types.Nodes{
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 0,
|
|
||||||
// MachineKey: machineKeys[0],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 1,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// Node: types.Node{
|
|
||||||
// MachineKey: machineKeys[0],
|
|
||||||
// },
|
|
||||||
// IsPrimary: true,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// routeID: 1,
|
|
||||||
// wantMachineKey: machineKeys[0],
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "failover-simple",
|
|
||||||
// nodes: types.Nodes{
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 0,
|
|
||||||
// MachineKey: machineKeys[0],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 1,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: true,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 1,
|
|
||||||
// MachineKey: machineKeys[1],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 2,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: false,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// routeID: 1,
|
|
||||||
// wantMachineKey: machineKeys[1],
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "no-failover-offline",
|
|
||||||
// nodes: types.Nodes{
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 0,
|
|
||||||
// MachineKey: machineKeys[0],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 1,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: true,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 1,
|
|
||||||
// MachineKey: machineKeys[1],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 2,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: false,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// isConnected: map[key.MachinePublic]bool{
|
|
||||||
// machineKeys[0]: true,
|
|
||||||
// machineKeys[1]: false,
|
|
||||||
// },
|
|
||||||
// routeID: 1,
|
|
||||||
// wantMachineKey: machineKeys[1],
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "failover-to-online",
|
|
||||||
// nodes: types.Nodes{
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 0,
|
|
||||||
// MachineKey: machineKeys[0],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 1,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: true,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// &types.Node{
|
|
||||||
// ID: 1,
|
|
||||||
// MachineKey: machineKeys[1],
|
|
||||||
// Routes: []types.Route{
|
|
||||||
// {
|
|
||||||
// Model: gorm.Model{
|
|
||||||
// ID: 2,
|
|
||||||
// },
|
|
||||||
// Prefix: ipp("10.0.0.0/24"),
|
|
||||||
// IsPrimary: false,
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// Hostinfo: &tailcfg.Hostinfo{
|
|
||||||
// RoutableIPs: []netip.Prefix{
|
|
||||||
// netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// isConnected: map[key.MachinePublic]bool{
|
|
||||||
// machineKeys[0]: true,
|
|
||||||
// machineKeys[1]: true,
|
|
||||||
// },
|
|
||||||
// routeID: 1,
|
|
||||||
// wantMachineKey: machineKeys[1],
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
|
|
||||||
// for _, tt := range tests {
|
if tt.want == nil && gotf != nil {
|
||||||
// t.Run(tt.name, func(t *testing.T) {
|
t.Fatalf("expected nil, got %+v", gotf)
|
||||||
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "")
|
}
|
||||||
// assert.NoError(t, err)
|
|
||||||
|
|
||||||
// // bootstrap db
|
if gotf == nil && tt.want != nil {
|
||||||
// datab.DB.Transaction(func(tx *gorm.DB) error {
|
t.Fatalf("expected %+v, got nil", tt.want)
|
||||||
// for _, node := range tt.nodes {
|
}
|
||||||
// err := tx.Save(node).Error
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// _, err = SaveNodeRoutes(tx, node)
|
if tt.want != nil && gotf != nil {
|
||||||
// if err != nil {
|
want := map[string]*types.Route{
|
||||||
// return err
|
"new": tt.want.new,
|
||||||
// }
|
"old": tt.want.old,
|
||||||
// }
|
}
|
||||||
|
|
||||||
// return nil
|
got := map[string]*types.Route{
|
||||||
// })
|
"new": gotf.new,
|
||||||
|
"old": gotf.old,
|
||||||
|
}
|
||||||
|
|
||||||
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
if diff := cmp.Diff(want, got, cmps...); diff != "" {
|
||||||
// return DisableRoute(tx, tt.routeID, tt.isConnected)
|
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
|
||||||
// })
|
}
|
||||||
|
}
|
||||||
// // if (err.Error() != "") != tt.wantErr {
|
})
|
||||||
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
}
|
||||||
|
}
|
||||||
// // return
|
|
||||||
// // }
|
|
||||||
|
|
||||||
// if len(got.ChangeNodes) != 1 {
|
|
||||||
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
|
|
||||||
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
|
@ -222,7 +222,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetNodeRequest,
|
request *v1.GetNodeRequest,
|
||||||
) (*v1.GetNodeResponse, error) {
|
) (*v1.GetNodeResponse, error) {
|
||||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -231,7 +231,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
|
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
|
||||||
|
|
||||||
return &v1.GetNodeResponse{Node: resp}, nil
|
return &v1.GetNodeResponse{Node: resp}, nil
|
||||||
}
|
}
|
||||||
|
@ -248,12 +248,12 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
}
|
}
|
||||||
|
|
||||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
|
err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.GetNodeByID(tx, request.GetNodeId())
|
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &v1.SetTagsResponse{
|
return &v1.SetTagsResponse{
|
||||||
|
@ -261,15 +261,12 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
}, status.Error(codes.InvalidArgument, err.Error())
|
}, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdate{
|
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||||
|
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
|
||||||
Type: types.StatePeerChanged,
|
Type: types.StatePeerChanged,
|
||||||
ChangeNodes: types.Nodes{node},
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
Message: "called from api.SetTags",
|
Message: "called from api.SetTags",
|
||||||
}
|
}, node.ID)
|
||||||
if stateUpdate.Valid() {
|
|
||||||
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
|
@ -296,12 +293,12 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteNodeRequest,
|
request *v1.DeleteNodeRequest,
|
||||||
) (*v1.DeleteNodeResponse, error) {
|
) (*v1.DeleteNodeResponse, error) {
|
||||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.db.DeleteNode(
|
changedNodes, err := api.h.db.DeleteNode(
|
||||||
node,
|
node,
|
||||||
api.h.nodeNotifier.ConnectedMap(),
|
api.h.nodeNotifier.ConnectedMap(),
|
||||||
)
|
)
|
||||||
|
@ -309,13 +306,17 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdate{
|
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||||
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StatePeerRemoved,
|
Type: types.StatePeerRemoved,
|
||||||
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
|
Removed: []types.NodeID{node.ID},
|
||||||
}
|
})
|
||||||
if stateUpdate.Valid() {
|
|
||||||
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
if changedNodes != nil {
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate)
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: changedNodes,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.DeleteNodeResponse{}, nil
|
return &v1.DeleteNodeResponse{}, nil
|
||||||
|
@ -330,33 +331,27 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
db.NodeSetExpiry(
|
db.NodeSetExpiry(
|
||||||
tx,
|
tx,
|
||||||
request.GetNodeId(),
|
types.NodeID(request.GetNodeId()),
|
||||||
now,
|
now,
|
||||||
)
|
)
|
||||||
|
|
||||||
return db.GetNodeByID(tx, request.GetNodeId())
|
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
selfUpdate := types.StateUpdate{
|
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||||
Type: types.StateSelfUpdate,
|
api.h.nodeNotifier.NotifyByMachineKey(
|
||||||
ChangeNodes: types.Nodes{node},
|
ctx,
|
||||||
}
|
types.StateUpdate{
|
||||||
if selfUpdate.Valid() {
|
Type: types.StateSelfUpdate,
|
||||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
api.h.nodeNotifier.NotifyByMachineKey(
|
},
|
||||||
ctx,
|
node.ID)
|
||||||
selfUpdate,
|
|
||||||
node.MachineKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
stateUpdate := types.StateUpdateExpire(node.ID, now)
|
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||||
if stateUpdate.Valid() {
|
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
|
||||||
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
|
@ -380,21 +375,18 @@ func (api headscaleV1APIServer) RenameNode(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.GetNodeByID(tx, request.GetNodeId())
|
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdate{
|
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
|
||||||
|
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
|
||||||
Type: types.StatePeerChanged,
|
Type: types.StatePeerChanged,
|
||||||
ChangeNodes: types.Nodes{node},
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
Message: "called from api.RenameNode",
|
Message: "called from api.RenameNode",
|
||||||
}
|
}, node.ID)
|
||||||
if stateUpdate.Valid() {
|
|
||||||
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
|
@ -423,7 +415,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
resp.Online = isConnected[node.MachineKey]
|
resp.Online = isConnected[node.ID]
|
||||||
|
|
||||||
response[index] = resp
|
response[index] = resp
|
||||||
}
|
}
|
||||||
|
@ -446,7 +438,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
resp.Online = isConnected[node.MachineKey]
|
resp.Online = isConnected[node.ID]
|
||||||
|
|
||||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
||||||
node,
|
node,
|
||||||
|
@ -463,7 +455,7 @@ func (api headscaleV1APIServer) MoveNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveNodeRequest,
|
request *v1.MoveNodeRequest,
|
||||||
) (*v1.MoveNodeResponse, error) {
|
) (*v1.MoveNodeResponse, error) {
|
||||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -503,7 +495,7 @@ func (api headscaleV1APIServer) EnableRoute(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if update != nil && update.Valid() {
|
if update != nil {
|
||||||
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
|
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
|
||||||
api.h.nodeNotifier.NotifyAll(
|
api.h.nodeNotifier.NotifyAll(
|
||||||
ctx, *update)
|
ctx, *update)
|
||||||
|
@ -516,17 +508,19 @@ func (api headscaleV1APIServer) DisableRoute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DisableRouteRequest,
|
request *v1.DisableRouteRequest,
|
||||||
) (*v1.DisableRouteResponse, error) {
|
) (*v1.DisableRouteResponse, error) {
|
||||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap())
|
||||||
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if update != nil && update.Valid() {
|
if update != nil {
|
||||||
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
|
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, *update)
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: update,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.DisableRouteResponse{}, nil
|
return &v1.DisableRouteResponse{}, nil
|
||||||
|
@ -536,7 +530,7 @@ func (api headscaleV1APIServer) GetNodeRoutes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetNodeRoutesRequest,
|
request *v1.GetNodeRoutesRequest,
|
||||||
) (*v1.GetNodeRoutesResponse, error) {
|
) (*v1.GetNodeRoutesResponse, error) {
|
||||||
node, err := api.h.db.GetNodeByID(request.GetNodeId())
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -556,16 +550,19 @@ func (api headscaleV1APIServer) DeleteRoute(
|
||||||
request *v1.DeleteRouteRequest,
|
request *v1.DeleteRouteRequest,
|
||||||
) (*v1.DeleteRouteResponse, error) {
|
) (*v1.DeleteRouteResponse, error) {
|
||||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
isConnected := api.h.nodeNotifier.ConnectedMap()
|
||||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if update != nil && update.Valid() {
|
if update != nil {
|
||||||
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
|
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update)
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: update,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.DeleteRouteResponse{}, nil
|
return &v1.DeleteRouteResponse{}, nil
|
||||||
|
|
|
@ -68,12 +68,6 @@ func (h *Headscale) KeyHandler(
|
||||||
Msg("could not get capability version")
|
Msg("could not get capability version")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
writer.WriteHeader(http.StatusInternalServerError)
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -82,19 +76,6 @@ func (h *Headscale) KeyHandler(
|
||||||
Str("handler", "/key").
|
Str("handler", "/key").
|
||||||
Int("cap_ver", int(capVer)).
|
Int("cap_ver", int(capVer)).
|
||||||
Msg("New noise client")
|
Msg("New noise client")
|
||||||
if err != nil {
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("Wrong params"))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TS2021 (Tailscale v2 protocol) requires to have a different key
|
// TS2021 (Tailscale v2 protocol) requires to have a different key
|
||||||
if capVer >= NoiseCapabilityVersion {
|
if capVer >= NoiseCapabilityVersion {
|
||||||
|
|
|
@ -16,12 +16,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mapset "github.com/deckarep/golang-set/v2"
|
mapset "github.com/deckarep/golang-set/v2"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/smallzstd"
|
"tailscale.com/smallzstd"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -51,21 +51,14 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
|
||||||
type Mapper struct {
|
type Mapper struct {
|
||||||
// Configuration
|
// Configuration
|
||||||
// TODO(kradalby): figure out if this is the format we want this in
|
// TODO(kradalby): figure out if this is the format we want this in
|
||||||
derpMap *tailcfg.DERPMap
|
db *db.HSDatabase
|
||||||
baseDomain string
|
cfg *types.Config
|
||||||
dnsCfg *tailcfg.DNSConfig
|
derpMap *tailcfg.DERPMap
|
||||||
logtail bool
|
isLikelyConnected types.NodeConnectedMap
|
||||||
randomClientPort bool
|
|
||||||
|
|
||||||
uid string
|
uid string
|
||||||
created time.Time
|
created time.Time
|
||||||
seq uint64
|
seq uint64
|
||||||
|
|
||||||
// Map isnt concurrency safe, so we need to ensure
|
|
||||||
// only one func is accessing it over time.
|
|
||||||
mu sync.Mutex
|
|
||||||
peers map[uint64]*types.Node
|
|
||||||
patches map[uint64][]patch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type patch struct {
|
type patch struct {
|
||||||
|
@ -74,35 +67,22 @@ type patch struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMapper(
|
func NewMapper(
|
||||||
node *types.Node,
|
db *db.HSDatabase,
|
||||||
peers types.Nodes,
|
cfg *types.Config,
|
||||||
derpMap *tailcfg.DERPMap,
|
derpMap *tailcfg.DERPMap,
|
||||||
baseDomain string,
|
isLikelyConnected types.NodeConnectedMap,
|
||||||
dnsCfg *tailcfg.DNSConfig,
|
|
||||||
logtail bool,
|
|
||||||
randomClientPort bool,
|
|
||||||
) *Mapper {
|
) *Mapper {
|
||||||
log.Debug().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("creating new mapper")
|
|
||||||
|
|
||||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||||
|
|
||||||
return &Mapper{
|
return &Mapper{
|
||||||
derpMap: derpMap,
|
db: db,
|
||||||
baseDomain: baseDomain,
|
cfg: cfg,
|
||||||
dnsCfg: dnsCfg,
|
derpMap: derpMap,
|
||||||
logtail: logtail,
|
isLikelyConnected: isLikelyConnected,
|
||||||
randomClientPort: randomClientPort,
|
|
||||||
|
|
||||||
uid: uid,
|
uid: uid,
|
||||||
created: time.Now(),
|
created: time.Now(),
|
||||||
seq: 0,
|
seq: 0,
|
||||||
|
|
||||||
// TODO: populate
|
|
||||||
peers: peers.IDMap(),
|
|
||||||
patches: make(map[uint64][]patch),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,11 +187,10 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||||
// It is a separate function to make testing easier.
|
// It is a separate function to make testing easier.
|
||||||
func (m *Mapper) fullMapResponse(
|
func (m *Mapper) fullMapResponse(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
|
peers types.Nodes,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
peers := nodeMapToList(m.peers)
|
|
||||||
|
|
||||||
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
|
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -219,14 +198,13 @@ func (m *Mapper) fullMapResponse(
|
||||||
|
|
||||||
err = appendPeerChanges(
|
err = appendPeerChanges(
|
||||||
resp,
|
resp,
|
||||||
|
true, // full change
|
||||||
pol,
|
pol,
|
||||||
node,
|
node,
|
||||||
capVer,
|
capVer,
|
||||||
peers,
|
peers,
|
||||||
peers,
|
peers,
|
||||||
m.baseDomain,
|
m.cfg,
|
||||||
m.dnsCfg,
|
|
||||||
m.randomClientPort,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -240,35 +218,25 @@ func (m *Mapper) FullMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
m.mu.Lock()
|
peers, err := m.ListPeers(node.ID)
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
peers := maps.Keys(m.peers)
|
|
||||||
peersWithPatches := maps.Keys(m.patches)
|
|
||||||
slices.Sort(peers)
|
|
||||||
slices.Sort(peersWithPatches)
|
|
||||||
|
|
||||||
if len(peersWithPatches) > 0 {
|
|
||||||
log.Debug().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Uints64("peers", peers).
|
|
||||||
Uints64("pending_patches", peersWithPatches).
|
|
||||||
Msgf("node requested full map response, but has pending patches")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
|
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LiteMapResponse returns a MapResponse for the given node.
|
// ReadOnlyResponse returns a MapResponse for the given node.
|
||||||
// Lite means that the peers has been omitted, this is intended
|
// Lite means that the peers has been omitted, this is intended
|
||||||
// to be used to answer MapRequests with OmitPeers set to true.
|
// to be used to answer MapRequests with OmitPeers set to true.
|
||||||
func (m *Mapper) LiteMapResponse(
|
func (m *Mapper) ReadOnlyMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
|
@ -279,18 +247,6 @@ func (m *Mapper) LiteMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
|
||||||
pol,
|
|
||||||
node,
|
|
||||||
nodeMapToList(m.peers),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
|
|
||||||
resp.SSHPolicy = sshPolicy
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,50 +276,74 @@ func (m *Mapper) DERPMapResponse(
|
||||||
func (m *Mapper) PeerChangedResponse(
|
func (m *Mapper) PeerChangedResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
changed types.Nodes,
|
changed map[types.NodeID]bool,
|
||||||
|
patches []*tailcfg.PeerChange,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
// Update our internal map.
|
|
||||||
for _, node := range changed {
|
|
||||||
if patches, ok := m.patches[node.ID]; ok {
|
|
||||||
// preserve online status in case the patch has an outdated one
|
|
||||||
online := node.IsOnline
|
|
||||||
|
|
||||||
for _, p := range patches {
|
|
||||||
// TODO(kradalby): Figure if this needs to be sorted by timestamp
|
|
||||||
node.ApplyPeerChange(p.change)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the patches are not applied again later
|
|
||||||
delete(m.patches, node.ID)
|
|
||||||
|
|
||||||
node.IsOnline = online
|
|
||||||
}
|
|
||||||
|
|
||||||
m.peers[node.ID] = node
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
err := appendPeerChanges(
|
peers, err := m.ListPeers(node.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var removedIDs []tailcfg.NodeID
|
||||||
|
var changedIDs []types.NodeID
|
||||||
|
for nodeID, nodeChanged := range changed {
|
||||||
|
if nodeChanged {
|
||||||
|
changedIDs = append(changedIDs, nodeID)
|
||||||
|
} else {
|
||||||
|
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
changedNodes := make(types.Nodes, 0, len(changedIDs))
|
||||||
|
for _, peer := range peers {
|
||||||
|
if slices.Contains(changedIDs, peer.ID) {
|
||||||
|
changedNodes = append(changedNodes, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = appendPeerChanges(
|
||||||
&resp,
|
&resp,
|
||||||
|
false, // partial change
|
||||||
pol,
|
pol,
|
||||||
node,
|
node,
|
||||||
mapRequest.Version,
|
mapRequest.Version,
|
||||||
nodeMapToList(m.peers),
|
peers,
|
||||||
changed,
|
changedNodes,
|
||||||
m.baseDomain,
|
m.cfg,
|
||||||
m.dnsCfg,
|
|
||||||
m.randomClientPort,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.PeersRemoved = removedIDs
|
||||||
|
|
||||||
|
// Sending patches as a part of a PeersChanged response
|
||||||
|
// is technically not suppose to be done, but they are
|
||||||
|
// applied after the PeersChanged. The patch list
|
||||||
|
// should _only_ contain Nodes that are not in the
|
||||||
|
// PeersChanged or PeersRemoved list and the caller
|
||||||
|
// should filter them out.
|
||||||
|
//
|
||||||
|
// From tailcfg docs:
|
||||||
|
// These are applied after Peers* above, but in practice the
|
||||||
|
// control server should only send these on their own, without
|
||||||
|
// the Peers* fields also set.
|
||||||
|
if patches != nil {
|
||||||
|
resp.PeersChangedPatch = patches
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the node itself, it might have changed, and particularly
|
||||||
|
// if there are no patches or changes, this is a self update.
|
||||||
|
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp.Node = tailnode
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,71 +355,12 @@ func (m *Mapper) PeerChangedPatchResponse(
|
||||||
changed []*tailcfg.PeerChange,
|
changed []*tailcfg.PeerChange,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
sendUpdate := false
|
|
||||||
// patch the internal map
|
|
||||||
for _, change := range changed {
|
|
||||||
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
|
|
||||||
peer.ApplyPeerChange(change)
|
|
||||||
sendUpdate = true
|
|
||||||
} else {
|
|
||||||
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
|
|
||||||
|
|
||||||
p := patch{
|
|
||||||
timestamp: time.Now(),
|
|
||||||
change: change,
|
|
||||||
}
|
|
||||||
|
|
||||||
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
|
|
||||||
m.patches[uint64(change.NodeID)] = append(patches, p)
|
|
||||||
} else {
|
|
||||||
m.patches[uint64(change.NodeID)] = []patch{p}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sendUpdate {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
resp.PeersChangedPatch = changed
|
resp.PeersChangedPatch = changed
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): We need some integration tests for this.
|
|
||||||
func (m *Mapper) PeerRemovedResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
node *types.Node,
|
|
||||||
removed []tailcfg.NodeID,
|
|
||||||
) ([]byte, error) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
// Some nodes might have been removed already
|
|
||||||
// so we dont want to ask downstream to remove
|
|
||||||
// twice, than can cause a panic in tailscaled.
|
|
||||||
notYetRemoved := []tailcfg.NodeID{}
|
|
||||||
|
|
||||||
// remove from our internal map
|
|
||||||
for _, id := range removed {
|
|
||||||
if _, ok := m.peers[uint64(id)]; ok {
|
|
||||||
notYetRemoved = append(notYetRemoved, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(m.peers, uint64(id))
|
|
||||||
delete(m.patches, uint64(id))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := m.baseMapResponse()
|
|
||||||
resp.PeersRemoved = notYetRemoved
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mapper) marshalMapResponse(
|
func (m *Mapper) marshalMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
resp *tailcfg.MapResponse,
|
resp *tailcfg.MapResponse,
|
||||||
|
@ -469,10 +390,8 @@ func (m *Mapper) marshalMapResponse(
|
||||||
switch {
|
switch {
|
||||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||||
responseType = "full"
|
responseType = "full"
|
||||||
case isSelfUpdate(messages...):
|
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||||
responseType = "self"
|
responseType = "self"
|
||||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
|
|
||||||
responseType = "lite"
|
|
||||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||||
responseType = "changed"
|
responseType = "changed"
|
||||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
||||||
|
@ -496,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UnixNano()
|
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||||
|
|
||||||
mapResponsePath := path.Join(
|
mapResponsePath := path.Join(
|
||||||
mPath,
|
mPath,
|
||||||
fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
||||||
)
|
)
|
||||||
|
|
||||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||||
|
@ -574,7 +493,7 @@ func (m *Mapper) baseWithConfigMapResponse(
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort)
|
tailnode, err := tailNode(node, capVer, pol, m.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -582,7 +501,7 @@ func (m *Mapper) baseWithConfigMapResponse(
|
||||||
|
|
||||||
resp.DERPMap = m.derpMap
|
resp.DERPMap = m.derpMap
|
||||||
|
|
||||||
resp.Domain = m.baseDomain
|
resp.Domain = m.cfg.BaseDomain
|
||||||
|
|
||||||
// Do not instruct clients to collect services we do not
|
// Do not instruct clients to collect services we do not
|
||||||
// support or do anything with them
|
// support or do anything with them
|
||||||
|
@ -591,12 +510,26 @@ func (m *Mapper) baseWithConfigMapResponse(
|
||||||
resp.KeepAlive = false
|
resp.KeepAlive = false
|
||||||
|
|
||||||
resp.Debug = &tailcfg.Debug{
|
resp.Debug = &tailcfg.Debug{
|
||||||
DisableLogTail: !m.logtail,
|
DisableLogTail: !m.cfg.LogTail.Enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
||||||
|
peers, err := m.db.ListPeers(nodeID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
online := m.isLikelyConnected[peer.ID]
|
||||||
|
peer.IsOnline = &online
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
||||||
ret := make(types.Nodes, 0)
|
ret := make(types.Nodes, 0)
|
||||||
|
|
||||||
|
@ -612,42 +545,41 @@ func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
||||||
func appendPeerChanges(
|
func appendPeerChanges(
|
||||||
resp *tailcfg.MapResponse,
|
resp *tailcfg.MapResponse,
|
||||||
|
|
||||||
|
fullChange bool,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
changed types.Nodes,
|
changed types.Nodes,
|
||||||
baseDomain string,
|
cfg *types.Config,
|
||||||
dnsCfg *tailcfg.DNSConfig,
|
|
||||||
randomClientPort bool,
|
|
||||||
) error {
|
) error {
|
||||||
fullChange := len(peers) == len(changed)
|
|
||||||
|
|
||||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
packetFilter, err := pol.CompileFilterRules(append(peers, node))
|
||||||
pol,
|
if err != nil {
|
||||||
node,
|
return err
|
||||||
peers,
|
}
|
||||||
)
|
|
||||||
|
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are filter rules present, see if there are any nodes that cannot
|
// If there are filter rules present, see if there are any nodes that cannot
|
||||||
// access eachother at all and remove them from the peers.
|
// access eachother at all and remove them from the peers.
|
||||||
if len(rules) > 0 {
|
if len(packetFilter) > 0 {
|
||||||
changed = policy.FilterNodesByACL(node, changed, rules)
|
changed = policy.FilterNodesByACL(node, changed, packetFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := generateUserProfiles(node, changed, baseDomain)
|
profiles := generateUserProfiles(node, changed, cfg.BaseDomain)
|
||||||
|
|
||||||
dnsConfig := generateDNSConfig(
|
dnsConfig := generateDNSConfig(
|
||||||
dnsCfg,
|
cfg.DNSConfig,
|
||||||
baseDomain,
|
cfg.BaseDomain,
|
||||||
node,
|
node,
|
||||||
peers,
|
peers,
|
||||||
)
|
)
|
||||||
|
|
||||||
tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort)
|
tailPeers, err := tailNodes(changed, capVer, pol, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -663,19 +595,9 @@ func appendPeerChanges(
|
||||||
resp.PeersChanged = tailPeers
|
resp.PeersChanged = tailPeers
|
||||||
}
|
}
|
||||||
resp.DNSConfig = dnsConfig
|
resp.DNSConfig = dnsConfig
|
||||||
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
|
resp.PacketFilter = policy.ReduceFilterRules(node, packetFilter)
|
||||||
resp.UserProfiles = profiles
|
resp.UserProfiles = profiles
|
||||||
resp.SSHPolicy = sshPolicy
|
resp.SSHPolicy = sshPolicy
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSelfUpdate(messages ...string) bool {
|
|
||||||
for _, message := range messages {
|
|
||||||
if strings.Contains(message, types.SelfUpdateIdentifier) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
|
@ -331,13 +331,10 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
node *types.Node
|
node *types.Node
|
||||||
peers types.Nodes
|
peers types.Nodes
|
||||||
|
|
||||||
baseDomain string
|
derpMap *tailcfg.DERPMap
|
||||||
dnsConfig *tailcfg.DNSConfig
|
cfg *types.Config
|
||||||
derpMap *tailcfg.DERPMap
|
want *tailcfg.MapResponse
|
||||||
logtail bool
|
wantErr bool
|
||||||
randomClientPort bool
|
|
||||||
want *tailcfg.MapResponse
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
// {
|
// {
|
||||||
// name: "empty-node",
|
// name: "empty-node",
|
||||||
|
@ -349,15 +346,17 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
// wantErr: true,
|
// wantErr: true,
|
||||||
// },
|
// },
|
||||||
{
|
{
|
||||||
name: "no-pol-no-peers-map-response",
|
name: "no-pol-no-peers-map-response",
|
||||||
pol: &policy.ACLPolicy{},
|
pol: &policy.ACLPolicy{},
|
||||||
node: mini,
|
node: mini,
|
||||||
peers: types.Nodes{},
|
peers: types.Nodes{},
|
||||||
baseDomain: "",
|
derpMap: &tailcfg.DERPMap{},
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
cfg: &types.Config{
|
||||||
derpMap: &tailcfg.DERPMap{},
|
BaseDomain: "",
|
||||||
logtail: false,
|
DNSConfig: &tailcfg.DNSConfig{},
|
||||||
randomClientPort: false,
|
LogTail: types.LogTailConfig{Enabled: false},
|
||||||
|
RandomizeClientPort: false,
|
||||||
|
},
|
||||||
want: &tailcfg.MapResponse{
|
want: &tailcfg.MapResponse{
|
||||||
Node: tailMini,
|
Node: tailMini,
|
||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
|
@ -383,11 +382,13 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
peers: types.Nodes{
|
peers: types.Nodes{
|
||||||
peer1,
|
peer1,
|
||||||
},
|
},
|
||||||
baseDomain: "",
|
derpMap: &tailcfg.DERPMap{},
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
cfg: &types.Config{
|
||||||
derpMap: &tailcfg.DERPMap{},
|
BaseDomain: "",
|
||||||
logtail: false,
|
DNSConfig: &tailcfg.DNSConfig{},
|
||||||
randomClientPort: false,
|
LogTail: types.LogTailConfig{Enabled: false},
|
||||||
|
RandomizeClientPort: false,
|
||||||
|
},
|
||||||
want: &tailcfg.MapResponse{
|
want: &tailcfg.MapResponse{
|
||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
Node: tailMini,
|
Node: tailMini,
|
||||||
|
@ -424,11 +425,13 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
peer1,
|
peer1,
|
||||||
peer2,
|
peer2,
|
||||||
},
|
},
|
||||||
baseDomain: "",
|
derpMap: &tailcfg.DERPMap{},
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
cfg: &types.Config{
|
||||||
derpMap: &tailcfg.DERPMap{},
|
BaseDomain: "",
|
||||||
logtail: false,
|
DNSConfig: &tailcfg.DNSConfig{},
|
||||||
randomClientPort: false,
|
LogTail: types.LogTailConfig{Enabled: false},
|
||||||
|
RandomizeClientPort: false,
|
||||||
|
},
|
||||||
want: &tailcfg.MapResponse{
|
want: &tailcfg.MapResponse{
|
||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
Node: tailMini,
|
Node: tailMini,
|
||||||
|
@ -463,17 +466,15 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
mappy := NewMapper(
|
mappy := NewMapper(
|
||||||
tt.node,
|
nil,
|
||||||
tt.peers,
|
tt.cfg,
|
||||||
tt.derpMap,
|
tt.derpMap,
|
||||||
tt.baseDomain,
|
nil,
|
||||||
tt.dnsConfig,
|
|
||||||
tt.logtail,
|
|
||||||
tt.randomClientPort,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
got, err := mappy.fullMapResponse(
|
got, err := mappy.fullMapResponse(
|
||||||
tt.node,
|
tt.node,
|
||||||
|
tt.peers,
|
||||||
tt.pol,
|
tt.pol,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,12 +3,10 @@ package mapper
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -17,9 +15,7 @@ func tailNodes(
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
cfg *types.Config,
|
||||||
baseDomain string,
|
|
||||||
randomClientPort bool,
|
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
tNodes := make([]*tailcfg.Node, len(nodes))
|
tNodes := make([]*tailcfg.Node, len(nodes))
|
||||||
|
|
||||||
|
@ -28,9 +24,7 @@ func tailNodes(
|
||||||
node,
|
node,
|
||||||
capVer,
|
capVer,
|
||||||
pol,
|
pol,
|
||||||
dnsConfig,
|
cfg,
|
||||||
baseDomain,
|
|
||||||
randomClientPort,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -48,9 +42,7 @@ func tailNode(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
cfg *types.Config,
|
||||||
baseDomain string,
|
|
||||||
randomClientPort bool,
|
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
addrs := node.IPAddresses.Prefixes()
|
addrs := node.IPAddresses.Prefixes()
|
||||||
|
|
||||||
|
@ -85,7 +77,7 @@ func tailNode(
|
||||||
keyExpiry = time.Time{}
|
keyExpiry = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname, err := node.GetFQDN(dnsConfig, baseDomain)
|
hostname, err := node.GetFQDN(cfg.DNSConfig, cfg.BaseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -94,12 +86,10 @@ func tailNode(
|
||||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||||
|
|
||||||
tNode := tailcfg.Node{
|
tNode := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(node.ID), // this is the actual ID
|
ID: tailcfg.NodeID(node.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(
|
StableID: node.ID.StableID(),
|
||||||
strconv.FormatUint(node.ID, util.Base10),
|
Name: hostname,
|
||||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
Cap: capVer,
|
||||||
Name: hostname,
|
|
||||||
Cap: capVer,
|
|
||||||
|
|
||||||
User: tailcfg.UserID(node.UserID),
|
User: tailcfg.UserID(node.UserID),
|
||||||
|
|
||||||
|
@ -133,7 +123,7 @@ func tailNode(
|
||||||
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
|
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if randomClientPort {
|
if cfg.RandomizeClientPort {
|
||||||
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
|
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -143,7 +133,7 @@ func tailNode(
|
||||||
tailcfg.CapabilitySSH,
|
tailcfg.CapabilitySSH,
|
||||||
}
|
}
|
||||||
|
|
||||||
if randomClientPort {
|
if cfg.RandomizeClientPort {
|
||||||
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
|
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -182,13 +182,16 @@ func TestTailNode(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &types.Config{
|
||||||
|
BaseDomain: tt.baseDomain,
|
||||||
|
DNSConfig: tt.dnsConfig,
|
||||||
|
RandomizeClientPort: false,
|
||||||
|
}
|
||||||
got, err := tailNode(
|
got, err := tailNode(
|
||||||
tt.node,
|
tt.node,
|
||||||
0,
|
0,
|
||||||
tt.pol,
|
tt.pol,
|
||||||
tt.dnsConfig,
|
cfg,
|
||||||
tt.baseDomain,
|
|
||||||
false,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package hscontrol
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/h2c"
|
"golang.org/x/net/http2/h2c"
|
||||||
|
"gorm.io/gorm"
|
||||||
"tailscale.com/control/controlbase"
|
"tailscale.com/control/controlbase"
|
||||||
"tailscale.com/control/controlhttp"
|
"tailscale.com/control/controlhttp"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -163,3 +165,135 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
MinimumCapVersion tailcfg.CapabilityVersion = 58
|
||||||
|
)
|
||||||
|
|
||||||
|
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
|
||||||
|
//
|
||||||
|
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
||||||
|
// the clients when something in the network changes.
|
||||||
|
//
|
||||||
|
// The clients POST stuff like HostInfo and their Endpoints here, but
|
||||||
|
// only after their first request (marked with the ReadOnly field).
|
||||||
|
//
|
||||||
|
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||||
|
func (ns *noiseServer) NoisePollNetMapHandler(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
) {
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "NoisePollNetMap").
|
||||||
|
Msg("PollNetMapHandler called")
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Any("headers", req.Header).
|
||||||
|
Caller().
|
||||||
|
Msg("Headers")
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
|
mapRequest := tailcfg.MapRequest{}
|
||||||
|
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot parse MapRequest")
|
||||||
|
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject unsupported versions
|
||||||
|
if mapRequest.Version < MinimumCapVersion {
|
||||||
|
log.Info().
|
||||||
|
Caller().
|
||||||
|
Int("min_version", int(MinimumCapVersion)).
|
||||||
|
Int("client_version", int(mapRequest.Version)).
|
||||||
|
Msg("unsupported client connected")
|
||||||
|
http.Error(writer, "Internal error", http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns.nodeKey = mapRequest.NodeKey
|
||||||
|
|
||||||
|
node, err := ns.headscale.db.GetNodeByAnyKey(
|
||||||
|
ns.conn.Peer(),
|
||||||
|
mapRequest.NodeKey,
|
||||||
|
key.NodePublic{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
log.Warn().
|
||||||
|
Str("handler", "NoisePollNetMap").
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
|
||||||
|
http.Error(writer, "Internal error", http.StatusNotFound)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "NoisePollNetMap").
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
|
||||||
|
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "NoisePollNetMap").
|
||||||
|
Str("node", node.Hostname).
|
||||||
|
Int("cap_ver", int(mapRequest.Version)).
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Msg("A node sending a MapRequest with Noise protocol")
|
||||||
|
|
||||||
|
session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
|
||||||
|
|
||||||
|
// If a streaming mapSession exists for this node, close it
|
||||||
|
// and start a new one.
|
||||||
|
if session.isStreaming() {
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Int("cap_ver", int(mapRequest.Version)).
|
||||||
|
Msg("Aquiring lock to check stream")
|
||||||
|
ns.headscale.mapSessionMu.Lock()
|
||||||
|
if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok {
|
||||||
|
log.Info().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Msg("Node has an open streaming session, replacing")
|
||||||
|
oldSession.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
ns.headscale.mapSessions[node.ID] = session
|
||||||
|
ns.headscale.mapSessionMu.Unlock()
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Int("cap_ver", int(mapRequest.Version)).
|
||||||
|
Msg("Releasing lock to check stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
session.serve()
|
||||||
|
|
||||||
|
if session.isStreaming() {
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Int("cap_ver", int(mapRequest.Version)).
|
||||||
|
Msg("Aquiring lock to remove stream")
|
||||||
|
ns.headscale.mapSessionMu.Lock()
|
||||||
|
|
||||||
|
delete(ns.headscale.mapSessions, node.ID)
|
||||||
|
|
||||||
|
ns.headscale.mapSessionMu.Unlock()
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
|
Int("cap_ver", int(mapRequest.Version)).
|
||||||
|
Msg("Releasing lock to remove stream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,52 +3,51 @@ package notifier
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
l sync.RWMutex
|
l sync.RWMutex
|
||||||
nodes map[string]chan<- types.StateUpdate
|
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||||
connected map[key.MachinePublic]bool
|
connected types.NodeConnectedMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNotifier() *Notifier {
|
func NewNotifier() *Notifier {
|
||||||
return &Notifier{
|
return &Notifier{
|
||||||
nodes: make(map[string]chan<- types.StateUpdate),
|
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||||
connected: make(map[key.MachinePublic]bool),
|
connected: make(types.NodeConnectedMap),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
|
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
|
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
|
||||||
defer log.Trace().
|
defer log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("key", machineKey.ShortString()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Msg("releasing lock to add node")
|
Msg("releasing lock to add node")
|
||||||
|
|
||||||
n.l.Lock()
|
n.l.Lock()
|
||||||
defer n.l.Unlock()
|
defer n.l.Unlock()
|
||||||
|
|
||||||
n.nodes[machineKey.String()] = c
|
n.nodes[nodeID] = c
|
||||||
n.connected[machineKey] = true
|
n.connected[nodeID] = true
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("machine_key", machineKey.ShortString()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Int("open_chans", len(n.nodes)).
|
Int("open_chans", len(n.nodes)).
|
||||||
Msg("Added new channel")
|
Msg("Added new channel")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
|
func (n *Notifier) RemoveNode(nodeID types.NodeID) {
|
||||||
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
|
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node")
|
||||||
defer log.Trace().
|
defer log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("key", machineKey.ShortString()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Msg("releasing lock to remove node")
|
Msg("releasing lock to remove node")
|
||||||
|
|
||||||
n.l.Lock()
|
n.l.Lock()
|
||||||
|
@ -58,26 +57,32 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(n.nodes, machineKey.String())
|
delete(n.nodes, nodeID)
|
||||||
n.connected[machineKey] = false
|
n.connected[nodeID] = false
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("machine_key", machineKey.ShortString()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Int("open_chans", len(n.nodes)).
|
Int("open_chans", len(n.nodes)).
|
||||||
Msg("Removed channel")
|
Msg("Removed channel")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsConnected reports if a node is connected to headscale and has a
|
// IsConnected reports if a node is connected to headscale and has a
|
||||||
// poll session open.
|
// poll session open.
|
||||||
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
|
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
return n.connected[machineKey]
|
return n.connected[nodeID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLikelyConnected reports if a node is connected to headscale and has a
|
||||||
|
// poll session open, but doesnt lock, so might be wrong.
|
||||||
|
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||||
|
return n.connected[nodeID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): This returns a pointer and can be dangerous.
|
// TODO(kradalby): This returns a pointer and can be dangerous.
|
||||||
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
|
func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
|
||||||
return n.connected
|
return n.connected
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,19 +93,23 @@ func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
||||||
func (n *Notifier) NotifyWithIgnore(
|
func (n *Notifier) NotifyWithIgnore(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
update types.StateUpdate,
|
update types.StateUpdate,
|
||||||
ignore ...string,
|
ignoreNodeIDs ...types.NodeID,
|
||||||
) {
|
) {
|
||||||
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
|
||||||
defer log.Trace().
|
defer log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Interface("type", update.Type).
|
Str("type", update.Type.String()).
|
||||||
Msg("releasing lock, finished notifying")
|
Msg("releasing lock, finished notifying")
|
||||||
|
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
for key, c := range n.nodes {
|
if update.Type == types.StatePeerChangedPatch {
|
||||||
if util.IsStringInSlice(ignore, key) {
|
log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
|
||||||
|
}
|
||||||
|
|
||||||
|
for nodeID, c := range n.nodes {
|
||||||
|
if slices.Contains(ignoreNodeIDs, nodeID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,17 +117,17 @@ func (n *Notifier) NotifyWithIgnore(
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(ctx.Err()).
|
Err(ctx.Err()).
|
||||||
Str("mkey", key).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update not sent, context cancelled")
|
Msgf("update not sent, context cancelled")
|
||||||
|
|
||||||
return
|
return
|
||||||
case c <- update:
|
case c <- update:
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("mkey", key).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update successfully sent on chan")
|
Msgf("update successfully sent on chan")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,33 +136,33 @@ func (n *Notifier) NotifyWithIgnore(
|
||||||
func (n *Notifier) NotifyByMachineKey(
|
func (n *Notifier) NotifyByMachineKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
update types.StateUpdate,
|
update types.StateUpdate,
|
||||||
mKey key.MachinePublic,
|
nodeID types.NodeID,
|
||||||
) {
|
) {
|
||||||
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
|
||||||
defer log.Trace().
|
defer log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Interface("type", update.Type).
|
Str("type", update.Type.String()).
|
||||||
Msg("releasing lock, finished notifying")
|
Msg("releasing lock, finished notifying")
|
||||||
|
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
if c, ok := n.nodes[mKey.String()]; ok {
|
if c, ok := n.nodes[nodeID]; ok {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(ctx.Err()).
|
Err(ctx.Err()).
|
||||||
Str("mkey", mKey.String()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update not sent, context cancelled")
|
Msgf("update not sent, context cancelled")
|
||||||
|
|
||||||
return
|
return
|
||||||
case c <- update:
|
case c <- update:
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("mkey", mKey.String()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update successfully sent on chan")
|
Msgf("update successfully sent on chan")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -166,7 +175,7 @@ func (n *Notifier) String() string {
|
||||||
str := []string{"Notifier, in map:\n"}
|
str := []string{"Notifier, in map:\n"}
|
||||||
|
|
||||||
for k, v := range n.nodes {
|
for k, v := range n.nodes {
|
||||||
str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
|
str = append(str, fmt.Sprintf("\t%d: %v\n", k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(str, "")
|
return strings.Join(str, "")
|
||||||
|
|
|
@ -537,11 +537,8 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
||||||
util.LogErr(err, "Failed to write response")
|
util.LogErr(err, "Failed to write response")
|
||||||
}
|
}
|
||||||
|
|
||||||
stateUpdate := types.StateUpdateExpire(node.ID, expiry)
|
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
|
||||||
if stateUpdate.Valid() {
|
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
|
|
||||||
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,7 +114,7 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
|
||||||
return &policy, nil
|
return &policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateFilterAndSSHRules(
|
func GenerateFilterAndSSHRulesForTests(
|
||||||
policy *ACLPolicy,
|
policy *ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
|
@ -124,40 +124,31 @@ func GenerateFilterAndSSHRules(
|
||||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := policy.generateFilterRules(node, peers)
|
rules, err := policy.CompileFilterRules(append(peers, node))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
||||||
|
|
||||||
var sshPolicy *tailcfg.SSHPolicy
|
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
|
||||||
sshRules, err := policy.generateSSHRules(node, peers)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Interface("SSH", sshRules).
|
|
||||||
Str("node", node.GivenName).
|
|
||||||
Msg("SSH rules")
|
|
||||||
|
|
||||||
if sshPolicy == nil {
|
|
||||||
sshPolicy = &tailcfg.SSHPolicy{}
|
|
||||||
}
|
|
||||||
sshPolicy.Rules = sshRules
|
|
||||||
|
|
||||||
return rules, sshPolicy, nil
|
return rules, sshPolicy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateFilterRules takes a set of nodes and an ACLPolicy and generates a
|
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
func (pol *ACLPolicy) generateFilterRules(
|
func (pol *ACLPolicy) CompileFilterRules(
|
||||||
node *types.Node,
|
nodes types.Nodes,
|
||||||
peers types.Nodes,
|
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
|
if pol == nil {
|
||||||
|
return tailcfg.FilterAllowAll, nil
|
||||||
|
}
|
||||||
|
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
nodes := append(peers, node)
|
|
||||||
|
|
||||||
for index, acl := range pol.ACLs {
|
for index, acl := range pol.ACLs {
|
||||||
if acl.Action != "accept" {
|
if acl.Action != "accept" {
|
||||||
|
@ -279,10 +270,14 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pol *ACLPolicy) generateSSHRules(
|
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
) ([]*tailcfg.SSHRule, error) {
|
) (*tailcfg.SSHPolicy, error) {
|
||||||
|
if pol == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
rules := []*tailcfg.SSHRule{}
|
rules := []*tailcfg.SSHRule{}
|
||||||
|
|
||||||
acceptAction := tailcfg.SSHAction{
|
acceptAction := tailcfg.SSHAction{
|
||||||
|
@ -393,7 +388,9 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return rules, nil
|
return &tailcfg.SSHPolicy{
|
||||||
|
Rules: rules,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||||
|
|
|
@ -385,11 +385,12 @@ acls:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Node{
|
rules, err := pol.CompileFilterRules(types.Nodes{
|
||||||
IPAddresses: types.NodeAddresses{
|
&types.Node{
|
||||||
netip.MustParseAddr("100.100.100.100"),
|
IPAddresses: types.NodeAddresses{
|
||||||
|
netip.MustParseAddr("100.100.100.100"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}, types.Nodes{
|
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPAddresses: types.NodeAddresses{
|
IPAddresses: types.NodeAddresses{
|
||||||
netip.MustParseAddr("200.200.200.200"),
|
netip.MustParseAddr("200.200.200.200"),
|
||||||
|
@ -546,7 +547,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
||||||
c.Assert(pol.ACLs, check.HasLen, 6)
|
c.Assert(pol.ACLs, check.HasLen, 6)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Node{}, types.Nodes{})
|
rules, err := pol.CompileFilterRules(types.Nodes{})
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(rules, check.IsNil)
|
c.Assert(rules, check.IsNil)
|
||||||
}
|
}
|
||||||
|
@ -562,7 +563,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
|
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,7 +582,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
|
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -597,7 +598,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{})
|
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1724,8 +1725,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
pol ACLPolicy
|
pol ACLPolicy
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
node *types.Node
|
nodes types.Nodes
|
||||||
peers types.Nodes
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -1755,13 +1755,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
node: &types.Node{
|
nodes: types.Nodes{
|
||||||
IPAddresses: types.NodeAddresses{
|
&types.Node{
|
||||||
netip.MustParseAddr("100.64.0.1"),
|
IPAddresses: types.NodeAddresses{
|
||||||
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
netip.MustParseAddr("100.64.0.1"),
|
||||||
|
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
peers: types.Nodes{},
|
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
{
|
{
|
||||||
|
@ -1800,14 +1801,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
node: &types.Node{
|
nodes: types.Nodes{
|
||||||
IPAddresses: types.NodeAddresses{
|
&types.Node{
|
||||||
netip.MustParseAddr("100.64.0.1"),
|
IPAddresses: types.NodeAddresses{
|
||||||
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
netip.MustParseAddr("100.64.0.1"),
|
||||||
|
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||||
|
},
|
||||||
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
User: types.User{Name: "mickael"},
|
|
||||||
},
|
|
||||||
peers: types.Nodes{
|
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPAddresses: types.NodeAddresses{
|
IPAddresses: types.NodeAddresses{
|
||||||
netip.MustParseAddr("100.64.0.2"),
|
netip.MustParseAddr("100.64.0.2"),
|
||||||
|
@ -1846,9 +1847,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.field.pol.generateFilterRules(
|
got, err := tt.field.pol.CompileFilterRules(
|
||||||
tt.args.node,
|
tt.args.nodes,
|
||||||
tt.args.peers,
|
|
||||||
)
|
)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -1980,9 +1980,8 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rules, _ := tt.pol.generateFilterRules(
|
rules, _ := tt.pol.CompileFilterRules(
|
||||||
tt.node,
|
append(tt.peers, tt.node),
|
||||||
tt.peers,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
got := ReduceFilterRules(tt.node, rules)
|
got := ReduceFilterRules(tt.node, rules)
|
||||||
|
@ -2883,7 +2882,7 @@ func TestSSHRules(t *testing.T) {
|
||||||
node types.Node
|
node types.Node
|
||||||
peers types.Nodes
|
peers types.Nodes
|
||||||
pol ACLPolicy
|
pol ACLPolicy
|
||||||
want []*tailcfg.SSHRule
|
want *tailcfg.SSHPolicy
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "peers-can-connect",
|
name: "peers-can-connect",
|
||||||
|
@ -2946,7 +2945,7 @@ func TestSSHRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []*tailcfg.SSHRule{
|
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||||
{
|
{
|
||||||
Principals: []*tailcfg.SSHPrincipal{
|
Principals: []*tailcfg.SSHPrincipal{
|
||||||
{
|
{
|
||||||
|
@ -2991,7 +2990,7 @@ func TestSSHRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
|
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
|
||||||
},
|
},
|
||||||
},
|
}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "peers-cannot-connect",
|
name: "peers-cannot-connect",
|
||||||
|
@ -3042,13 +3041,13 @@ func TestSSHRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []*tailcfg.SSHRule{},
|
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.pol.generateSSHRules(&tt.node, tt.peers)
|
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
@ -3155,7 +3154,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
|
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -3206,7 +3205,7 @@ func TestInvalidTagValidUser(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
|
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -3265,7 +3264,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
||||||
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||||
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||||
|
|
||||||
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{})
|
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -3335,7 +3334,7 @@ func TestValidTagInvalidUser(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{nodes2})
|
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
|
1112
hscontrol/poll.go
1112
hscontrol/poll.go
File diff suppressed because it is too large
Load diff
|
@ -1,96 +0,0 @@
|
||||||
package hscontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
MinimumCapVersion tailcfg.CapabilityVersion = 58
|
|
||||||
)
|
|
||||||
|
|
||||||
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
|
|
||||||
//
|
|
||||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
|
||||||
// the clients when something in the network changes.
|
|
||||||
//
|
|
||||||
// The clients POST stuff like HostInfo and their Endpoints here, but
|
|
||||||
// only after their first request (marked with the ReadOnly field).
|
|
||||||
//
|
|
||||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
|
||||||
func (ns *noiseServer) NoisePollNetMapHandler(
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
req *http.Request,
|
|
||||||
) {
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "NoisePollNetMap").
|
|
||||||
Msg("PollNetMapHandler called")
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Any("headers", req.Header).
|
|
||||||
Caller().
|
|
||||||
Msg("Headers")
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(req.Body)
|
|
||||||
|
|
||||||
mapRequest := tailcfg.MapRequest{}
|
|
||||||
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot parse MapRequest")
|
|
||||||
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reject unsupported versions
|
|
||||||
if mapRequest.Version < MinimumCapVersion {
|
|
||||||
log.Info().
|
|
||||||
Caller().
|
|
||||||
Int("min_version", int(MinimumCapVersion)).
|
|
||||||
Int("client_version", int(mapRequest.Version)).
|
|
||||||
Msg("unsupported client connected")
|
|
||||||
http.Error(writer, "Internal error", http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ns.nodeKey = mapRequest.NodeKey
|
|
||||||
|
|
||||||
node, err := ns.headscale.db.GetNodeByAnyKey(
|
|
||||||
ns.conn.Peer(),
|
|
||||||
mapRequest.NodeKey,
|
|
||||||
key.NodePublic{},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
log.Warn().
|
|
||||||
Str("handler", "NoisePollNetMap").
|
|
||||||
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
|
|
||||||
http.Error(writer, "Internal error", http.StatusNotFound)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "NoisePollNetMap").
|
|
||||||
Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
|
|
||||||
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Debug().
|
|
||||||
Str("handler", "NoisePollNetMap").
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Int("cap_ver", int(mapRequest.Version)).
|
|
||||||
Msg("A node sending a MapRequest with Noise protocol")
|
|
||||||
|
|
||||||
ns.headscale.handlePoll(writer, req.Context(), node, mapRequest)
|
|
||||||
}
|
|
|
@ -90,6 +90,25 @@ func (i StringList) Value() (driver.Value, error) {
|
||||||
|
|
||||||
type StateUpdateType int
|
type StateUpdateType int
|
||||||
|
|
||||||
|
func (su StateUpdateType) String() string {
|
||||||
|
switch su {
|
||||||
|
case StateFullUpdate:
|
||||||
|
return "StateFullUpdate"
|
||||||
|
case StatePeerChanged:
|
||||||
|
return "StatePeerChanged"
|
||||||
|
case StatePeerChangedPatch:
|
||||||
|
return "StatePeerChangedPatch"
|
||||||
|
case StatePeerRemoved:
|
||||||
|
return "StatePeerRemoved"
|
||||||
|
case StateSelfUpdate:
|
||||||
|
return "StateSelfUpdate"
|
||||||
|
case StateDERPUpdated:
|
||||||
|
return "StateDERPUpdated"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown state update type"
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
StateFullUpdate StateUpdateType = iota
|
StateFullUpdate StateUpdateType = iota
|
||||||
// StatePeerChanged is used for updates that needs
|
// StatePeerChanged is used for updates that needs
|
||||||
|
@ -118,7 +137,7 @@ type StateUpdate struct {
|
||||||
// ChangeNodes must be set when Type is StatePeerAdded
|
// ChangeNodes must be set when Type is StatePeerAdded
|
||||||
// and StatePeerChanged and contains the full node
|
// and StatePeerChanged and contains the full node
|
||||||
// object for added nodes.
|
// object for added nodes.
|
||||||
ChangeNodes Nodes
|
ChangeNodes []NodeID
|
||||||
|
|
||||||
// ChangePatches must be set when Type is StatePeerChangedPatch
|
// ChangePatches must be set when Type is StatePeerChangedPatch
|
||||||
// and contains a populated PeerChange object.
|
// and contains a populated PeerChange object.
|
||||||
|
@ -127,7 +146,7 @@ type StateUpdate struct {
|
||||||
// Removed must be set when Type is StatePeerRemoved and
|
// Removed must be set when Type is StatePeerRemoved and
|
||||||
// contain a list of the nodes that has been removed from
|
// contain a list of the nodes that has been removed from
|
||||||
// the network.
|
// the network.
|
||||||
Removed []tailcfg.NodeID
|
Removed []NodeID
|
||||||
|
|
||||||
// DERPMap must be set when Type is StateDERPUpdated and
|
// DERPMap must be set when Type is StateDERPUpdated and
|
||||||
// contain the new DERP Map.
|
// contain the new DERP Map.
|
||||||
|
@ -138,39 +157,6 @@ type StateUpdate struct {
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid reports if a StateUpdate is correctly filled and
|
|
||||||
// panics if the mandatory fields for a type is not
|
|
||||||
// filled.
|
|
||||||
// Reports true if valid.
|
|
||||||
func (su *StateUpdate) Valid() bool {
|
|
||||||
switch su.Type {
|
|
||||||
case StatePeerChanged:
|
|
||||||
if su.ChangeNodes == nil {
|
|
||||||
panic("Mandatory field ChangeNodes is not set on StatePeerChanged update")
|
|
||||||
}
|
|
||||||
case StatePeerChangedPatch:
|
|
||||||
if su.ChangePatches == nil {
|
|
||||||
panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update")
|
|
||||||
}
|
|
||||||
case StatePeerRemoved:
|
|
||||||
if su.Removed == nil {
|
|
||||||
panic("Mandatory field Removed is not set on StatePeerRemove update")
|
|
||||||
}
|
|
||||||
case StateSelfUpdate:
|
|
||||||
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
|
|
||||||
panic(
|
|
||||||
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
case StateDERPUpdated:
|
|
||||||
if su.DERPMap == nil {
|
|
||||||
panic("Mandatory field DERPMap is not set on StateDERPUpdated update")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Empty reports if there are any updates in the StateUpdate.
|
// Empty reports if there are any updates in the StateUpdate.
|
||||||
func (su *StateUpdate) Empty() bool {
|
func (su *StateUpdate) Empty() bool {
|
||||||
switch su.Type {
|
switch su.Type {
|
||||||
|
@ -185,12 +171,12 @@ func (su *StateUpdate) Empty() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate {
|
func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||||
return StateUpdate{
|
return StateUpdate{
|
||||||
Type: StatePeerChangedPatch,
|
Type: StatePeerChangedPatch,
|
||||||
ChangePatches: []*tailcfg.PeerChange{
|
ChangePatches: []*tailcfg.PeerChange{
|
||||||
{
|
{
|
||||||
NodeID: tailcfg.NodeID(nodeID),
|
NodeID: nodeID.NodeID(),
|
||||||
KeyExpiry: &expiry,
|
KeyExpiry: &expiry,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -69,6 +69,8 @@ type Config struct {
|
||||||
CLI CLIConfig
|
CLI CLIConfig
|
||||||
|
|
||||||
ACL ACLConfig
|
ACL ACLConfig
|
||||||
|
|
||||||
|
Tuning Tuning
|
||||||
}
|
}
|
||||||
|
|
||||||
type SqliteConfig struct {
|
type SqliteConfig struct {
|
||||||
|
@ -161,6 +163,11 @@ type LogConfig struct {
|
||||||
Level zerolog.Level
|
Level zerolog.Level
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tuning struct {
|
||||||
|
BatchChangeDelay time.Duration
|
||||||
|
NodeMapSessionBufferedChanSize int
|
||||||
|
}
|
||||||
|
|
||||||
func LoadConfig(path string, isFile bool) error {
|
func LoadConfig(path string, isFile bool) error {
|
||||||
if isFile {
|
if isFile {
|
||||||
viper.SetConfigFile(path)
|
viper.SetConfigFile(path)
|
||||||
|
@ -220,6 +227,9 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
|
|
||||||
viper.SetDefault("node_update_check_interval", "10s")
|
viper.SetDefault("node_update_check_interval", "10s")
|
||||||
|
|
||||||
|
viper.SetDefault("tuning.batch_change_delay", "800ms")
|
||||||
|
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
|
||||||
|
|
||||||
if IsCLIConfigured() {
|
if IsCLIConfigured() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -719,6 +729,12 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
},
|
},
|
||||||
|
|
||||||
Log: GetLogConfig(),
|
Log: GetLogConfig(),
|
||||||
|
|
||||||
|
// TODO(kradalby): Document these settings when more stable
|
||||||
|
Tuning: Tuning{
|
||||||
|
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
||||||
|
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,13 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
@ -27,9 +29,24 @@ var (
|
||||||
ErrNodeUserHasNoName = errors.New("node user has no name")
|
ErrNodeUserHasNoName = errors.New("node user has no name")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type NodeID uint64
|
||||||
|
type NodeConnectedMap map[NodeID]bool
|
||||||
|
|
||||||
|
func (id NodeID) StableID() tailcfg.StableNodeID {
|
||||||
|
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (id NodeID) NodeID() tailcfg.NodeID {
|
||||||
|
return tailcfg.NodeID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (id NodeID) Uint64() uint64 {
|
||||||
|
return uint64(id)
|
||||||
|
}
|
||||||
|
|
||||||
// Node is a Headscale client.
|
// Node is a Headscale client.
|
||||||
type Node struct {
|
type Node struct {
|
||||||
ID uint64 `gorm:"primary_key"`
|
ID NodeID `gorm:"primary_key"`
|
||||||
|
|
||||||
// MachineKeyDatabaseField is the string representation of MachineKey
|
// MachineKeyDatabaseField is the string representation of MachineKey
|
||||||
// it is _only_ used for reading and writing the key to the
|
// it is _only_ used for reading and writing the key to the
|
||||||
|
@ -198,7 +215,7 @@ func (node Node) IsExpired() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now().UTC().After(*node.Expiry)
|
return time.Since(*node.Expiry) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEphemeral returns if the node is registered as an Ephemeral node.
|
// IsEphemeral returns if the node is registered as an Ephemeral node.
|
||||||
|
@ -319,7 +336,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
||||||
|
|
||||||
func (node *Node) Proto() *v1.Node {
|
func (node *Node) Proto() *v1.Node {
|
||||||
nodeProto := &v1.Node{
|
nodeProto := &v1.Node{
|
||||||
Id: node.ID,
|
Id: uint64(node.ID),
|
||||||
MachineKey: node.MachineKey.String(),
|
MachineKey: node.MachineKey.String(),
|
||||||
|
|
||||||
NodeKey: node.NodeKey.String(),
|
NodeKey: node.NodeKey.String(),
|
||||||
|
@ -486,8 +503,8 @@ func (nodes Nodes) String() string {
|
||||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nodes Nodes) IDMap() map[uint64]*Node {
|
func (nodes Nodes) IDMap() map[NodeID]*Node {
|
||||||
ret := map[uint64]*Node{}
|
ret := map[NodeID]*Node{}
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
ret[node.ID] = node
|
ret[node.ID] = node
|
||||||
|
|
|
@ -83,7 +83,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -142,7 +142,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
|
|
@ -53,7 +53,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -92,7 +92,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
|
|
@ -65,7 +65,7 @@ func TestPingAllByIP(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -103,7 +103,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -135,7 +135,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
|
@ -176,7 +176,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allClients, err = scenario.ListTailscaleClients()
|
allClients, err = scenario.ListTailscaleClients()
|
||||||
assertNoErrListClients(t, err)
|
assertNoErrListClients(t, err)
|
||||||
|
@ -329,7 +329,7 @@ func TestPingAllByHostname(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||||
assertNoErrListFQDN(t, err)
|
assertNoErrListFQDN(t, err)
|
||||||
|
@ -539,7 +539,7 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
// Poor mans cache
|
// Poor mans cache
|
||||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||||
|
@ -609,7 +609,7 @@ func TestExpireNode(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -711,7 +711,7 @@ func TestExpireNode(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
func TestNodeOnlineStatus(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -723,7 +723,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
"user1": len(MustTestVersions),
|
"user1": len(MustTestVersions),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online"))
|
||||||
assertNoErrHeadscaleEnv(t, err)
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
allClients, err := scenario.ListTailscaleClients()
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
@ -735,7 +735,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
assertClientsState(t, allClients)
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
|
@ -755,8 +755,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
keepAliveInterval := 60 * time.Second
|
|
||||||
|
|
||||||
// Duration is chosen arbitrarily, 10m is reported in #1561
|
// Duration is chosen arbitrarily, 10m is reported in #1561
|
||||||
testDuration := 12 * time.Minute
|
testDuration := 12 * time.Minute
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
@ -780,11 +778,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
err = json.Unmarshal([]byte(result), &nodes)
|
err = json.Unmarshal([]byte(result), &nodes)
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
// Threshold with some leeway
|
|
||||||
lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second))
|
|
||||||
|
|
||||||
// Verify that headscale reports the nodes as online
|
// Verify that headscale reports the nodes as online
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
// All nodes should be online
|
// All nodes should be online
|
||||||
|
@ -795,18 +788,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
node.GetName(),
|
node.GetName(),
|
||||||
time.Since(start),
|
time.Since(start),
|
||||||
)
|
)
|
||||||
|
|
||||||
lastSeen := node.GetLastSeen().AsTime()
|
|
||||||
// All nodes should have been last seen between now and the keepAliveInterval
|
|
||||||
assert.Truef(
|
|
||||||
t,
|
|
||||||
lastSeen.After(lastSeenThreshold),
|
|
||||||
"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
|
|
||||||
node.GetName(),
|
|
||||||
lastSeen,
|
|
||||||
keepAliveInterval,
|
|
||||||
lastSeenThreshold,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that all nodes report all nodes to be online
|
// Verify that all nodes report all nodes to be online
|
||||||
|
@ -834,15 +815,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
client.Hostname(),
|
client.Hostname(),
|
||||||
time.Since(start),
|
time.Since(start),
|
||||||
)
|
)
|
||||||
|
|
||||||
// from docs: last seen to tailcontrol; only present if offline
|
|
||||||
// assert.Nilf(
|
|
||||||
// t,
|
|
||||||
// peerStatus.LastSeen,
|
|
||||||
// "expected node %s to not have LastSeen set, got %s",
|
|
||||||
// peerStatus.HostName,
|
|
||||||
// peerStatus.LastSeen,
|
|
||||||
// )
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -850,3 +822,87 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPingAllByIPManyUpDown is a variant of the PingAll
|
||||||
|
// test which will take the tailscale node up and down
|
||||||
|
// five times ensuring they are able to restablish connectivity.
|
||||||
|
func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scenario, err := NewScenario()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
// TODO(kradalby): it does not look like the user thing works, only second
|
||||||
|
// get created? maybe only when many?
|
||||||
|
spec := map[string]int{
|
||||||
|
"user1": len(MustTestVersions),
|
||||||
|
"user2": len(MustTestVersions),
|
||||||
|
}
|
||||||
|
|
||||||
|
headscaleConfig := map[string]string{
|
||||||
|
"HEADSCALE_DERP_URLS": "",
|
||||||
|
"HEADSCALE_DERP_SERVER_ENABLED": "true",
|
||||||
|
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
|
||||||
|
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
|
||||||
|
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
|
||||||
|
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
|
||||||
|
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
|
||||||
|
|
||||||
|
// Envknob for enabling DERP debug logs
|
||||||
|
"DERP_DEBUG_LOGS": "true",
|
||||||
|
"DERP_PROBER_DEBUG_LOGS": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec,
|
||||||
|
[]tsic.Option{},
|
||||||
|
hsic.WithTestName("pingallbyip"),
|
||||||
|
hsic.WithConfigEnv(headscaleConfig),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithHostnameAsServerURL(),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||||
|
assertNoErrListClientIPs(t, err)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
|
return x.String()
|
||||||
|
})
|
||||||
|
|
||||||
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
for run := range 3 {
|
||||||
|
t.Logf("Starting DownUpPing run %d", run+1)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
t.Logf("taking down %q", client.Hostname())
|
||||||
|
client.Down()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
t.Logf("bringing up %q", client.Hostname())
|
||||||
|
client.Up()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -212,7 +212,11 @@ func TestEnablingRoutes(t *testing.T) {
|
||||||
|
|
||||||
if route.GetId() == routeToBeDisabled.GetId() {
|
if route.GetId() == routeToBeDisabled.GetId() {
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.Equal(t, false, route.GetEnabled())
|
||||||
assert.Equal(t, false, route.GetIsPrimary())
|
|
||||||
|
// since this is the only route of this cidr,
|
||||||
|
// it will not failover, and remain Primary
|
||||||
|
// until something can replace it.
|
||||||
|
assert.Equal(t, true, route.GetIsPrimary())
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, true, route.GetEnabled())
|
assert.Equal(t, true, route.GetEnabled())
|
||||||
assert.Equal(t, true, route.GetIsPrimary())
|
assert.Equal(t, true, route.GetIsPrimary())
|
||||||
|
@ -291,6 +295,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
|
|
||||||
client := allClients[2]
|
client := allClients[2]
|
||||||
|
|
||||||
|
t.Logf("Advertise route from r1 (%s) and r2 (%s), making it HA, n1 is primary", subRouter1.Hostname(), subRouter2.Hostname())
|
||||||
// advertise HA route on node 1 and 2
|
// advertise HA route on node 1 and 2
|
||||||
// ID 1 will be primary
|
// ID 1 will be primary
|
||||||
// ID 2 will be secondary
|
// ID 2 will be secondary
|
||||||
|
@ -384,12 +389,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
|
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
|
||||||
|
|
||||||
// Node 2 is not primary
|
// Node 2 is not primary
|
||||||
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
|
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
|
||||||
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
|
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
|
||||||
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary())
|
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine
|
// Verify that the client has routes from the primary machine
|
||||||
srs1, err := subRouter1.Status()
|
srs1, err := subRouter1.Status()
|
||||||
|
@ -401,6 +406,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
|
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
|
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
|
||||||
|
|
||||||
|
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
|
||||||
|
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
|
||||||
|
|
||||||
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
|
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
@ -411,7 +419,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
)
|
)
|
||||||
|
|
||||||
// Take down the current primary
|
// Take down the current primary
|
||||||
t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname())
|
t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
|
||||||
|
t.Logf("expecting r2 (%s) to take over as primary", subRouter2.Hostname())
|
||||||
err = subRouter1.Down()
|
err = subRouter1.Down()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
@ -435,15 +444,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
|
assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterMove[0].GetEnabled())
|
assert.Equal(t, true, routesAfterMove[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary())
|
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
|
assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetEnabled())
|
assert.Equal(t, true, routesAfterMove[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary())
|
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
||||||
|
|
||||||
// TODO(kradalby): Check client status
|
|
||||||
// Route is expected to be on SR2
|
|
||||||
|
|
||||||
srs2, err = subRouter2.Status()
|
srs2, err = subRouter2.Status()
|
||||||
|
|
||||||
|
@ -453,6 +459,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||||
|
|
||||||
|
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||||
|
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
|
||||||
|
|
||||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
|
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
@ -465,7 +474,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take down subnet router 2, leaving none available
|
// Take down subnet router 2, leaving none available
|
||||||
t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname())
|
t.Logf("taking down subnet router r2 (%s)", subRouter2.Hostname())
|
||||||
|
t.Logf("expecting r2 (%s) to remain primary, no other available", subRouter2.Hostname())
|
||||||
err = subRouter2.Down()
|
err = subRouter2.Down()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
@ -489,14 +499,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
|
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
|
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary())
|
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
// if the node goes down, but no other suitable route is
|
// if the node goes down, but no other suitable route is
|
||||||
// available, keep the last known good route.
|
// available, keep the last known good route.
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
|
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
|
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary())
|
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
||||||
|
|
||||||
// TODO(kradalby): Check client status
|
// TODO(kradalby): Check client status
|
||||||
// Both are expected to be down
|
// Both are expected to be down
|
||||||
|
@ -508,6 +518,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||||
|
|
||||||
|
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||||
|
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
|
||||||
|
|
||||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
|
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
@ -520,7 +533,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up subnet router 1, making the route available from there.
|
// Bring up subnet router 1, making the route available from there.
|
||||||
t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname())
|
t.Logf("bringing up subnet router r1 (%s)", subRouter1.Hostname())
|
||||||
|
t.Logf("expecting r1 (%s) to take over as primary (only one online)", subRouter1.Hostname())
|
||||||
err = subRouter1.Up()
|
err = subRouter1.Up()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
@ -544,12 +558,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
|
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
|
assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary())
|
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
||||||
|
|
||||||
// Node 2 is not primary
|
// Node 2 is not primary
|
||||||
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
|
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
|
assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary())
|
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -558,6 +572,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||||
|
|
||||||
|
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
|
||||||
|
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
|
||||||
|
|
||||||
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
|
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
@ -570,7 +587,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up subnet router 2, should result in no change.
|
// Bring up subnet router 2, should result in no change.
|
||||||
t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname())
|
t.Logf("bringing up subnet router r2 (%s)", subRouter2.Hostname())
|
||||||
|
t.Logf("both online, expecting r1 (%s) to still be primary (no flapping)", subRouter1.Hostname())
|
||||||
err = subRouter2.Up()
|
err = subRouter2.Up()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
@ -594,12 +612,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
|
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
|
assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary())
|
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
|
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
|
assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary())
|
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -608,6 +626,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||||
|
|
||||||
|
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
|
||||||
|
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
|
||||||
|
|
||||||
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
|
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
|
||||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||||
|
|
||||||
|
@ -620,7 +641,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable the route of subnet router 1, making it failover to 2
|
// Disable the route of subnet router 1, making it failover to 2
|
||||||
t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
|
t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
|
||||||
|
t.Logf("expecting route to failover to r2 (%s), which is still available", subRouter2.Hostname())
|
||||||
_, err = headscale.Execute(
|
_, err = headscale.Execute(
|
||||||
[]string{
|
[]string{
|
||||||
"headscale",
|
"headscale",
|
||||||
|
@ -648,7 +670,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
assert.Len(t, routesAfterDisabling1, 2)
|
assert.Len(t, routesAfterDisabling1, 2)
|
||||||
|
|
||||||
t.Logf("routes after disabling1 %#v", routesAfterDisabling1)
|
t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
|
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
|
||||||
|
@ -680,6 +702,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
|
|
||||||
// enable the route of subnet router 1, no change expected
|
// enable the route of subnet router 1, no change expected
|
||||||
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
|
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
|
||||||
|
t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
|
||||||
_, err = headscale.Execute(
|
_, err = headscale.Execute(
|
||||||
[]string{
|
[]string{
|
||||||
"headscale",
|
"headscale",
|
||||||
|
@ -736,7 +759,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete the route of subnet router 2, failover to one expected
|
// delete the route of subnet router 2, failover to one expected
|
||||||
t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
|
t.Logf("deleting route in subnet router r2 (%s)", subRouter2.Hostname())
|
||||||
|
t.Logf("expecting route to failover to r1 (%s)", subRouter1.Hostname())
|
||||||
_, err = headscale.Execute(
|
_, err = headscale.Execute(
|
||||||
[]string{
|
[]string{
|
||||||
"headscale",
|
"headscale",
|
||||||
|
@ -764,7 +788,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
assert.Len(t, routesAfterDeleting2, 1)
|
assert.Len(t, routesAfterDeleting2, 1)
|
||||||
|
|
||||||
t.Logf("routes after deleting2 %#v", routesAfterDeleting2)
|
t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
|
||||||
|
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
|
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
|
||||||
|
|
|
@ -50,6 +50,8 @@ var (
|
||||||
tailscaleVersions2021 = map[string]bool{
|
tailscaleVersions2021 = map[string]bool{
|
||||||
"head": true,
|
"head": true,
|
||||||
"unstable": true,
|
"unstable": true,
|
||||||
|
"1.60": true, // CapVer: 82
|
||||||
|
"1.58": true, // CapVer: 82
|
||||||
"1.56": true, // CapVer: 82
|
"1.56": true, // CapVer: 82
|
||||||
"1.54": true, // CapVer: 79
|
"1.54": true, // CapVer: 79
|
||||||
"1.52": true, // CapVer: 79
|
"1.52": true, // CapVer: 79
|
||||||
|
|
|
@ -27,7 +27,7 @@ type TailscaleClient interface {
|
||||||
Down() error
|
Down() error
|
||||||
IPs() ([]netip.Addr, error)
|
IPs() ([]netip.Addr, error)
|
||||||
FQDN() (string, error)
|
FQDN() (string, error)
|
||||||
Status() (*ipnstate.Status, error)
|
Status(...bool) (*ipnstate.Status, error)
|
||||||
Netmap() (*netmap.NetworkMap, error)
|
Netmap() (*netmap.NetworkMap, error)
|
||||||
Netcheck() (*netcheck.Report, error)
|
Netcheck() (*netcheck.Report, error)
|
||||||
WaitForNeedsLogin() error
|
WaitForNeedsLogin() error
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -503,7 +504,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status returns the ipnstate.Status of the Tailscale instance.
|
// Status returns the ipnstate.Status of the Tailscale instance.
|
||||||
func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
|
func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
|
||||||
command := []string{
|
command := []string{
|
||||||
"tailscale",
|
"tailscale",
|
||||||
"status",
|
"status",
|
||||||
|
@ -521,60 +522,70 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
|
||||||
return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("status netmap to /tmp/control: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &status, err
|
return &status, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
|
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
|
||||||
// Only works with Tailscale 1.56 and newer.
|
// Only works with Tailscale 1.56 and newer.
|
||||||
// Panics if version is lower then minimum.
|
// Panics if version is lower then minimum.
|
||||||
// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
|
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
|
||||||
// if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
|
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
|
||||||
// panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
|
panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
|
||||||
// }
|
}
|
||||||
|
|
||||||
// command := []string{
|
command := []string{
|
||||||
// "tailscale",
|
"tailscale",
|
||||||
// "debug",
|
"debug",
|
||||||
// "netmap",
|
"netmap",
|
||||||
// }
|
}
|
||||||
|
|
||||||
// result, stderr, err := t.Execute(command)
|
result, stderr, err := t.Execute(command)
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// fmt.Printf("stderr: %s\n", stderr)
|
fmt.Printf("stderr: %s\n", stderr)
|
||||||
// return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
|
return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// var nm netmap.NetworkMap
|
var nm netmap.NetworkMap
|
||||||
// err = json.Unmarshal([]byte(result), &nm)
|
err = json.Unmarshal([]byte(result), &nm)
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// return &nm, err
|
err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755)
|
||||||
// }
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("saving netmap to /tmp/control: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &nm, err
|
||||||
|
}
|
||||||
|
|
||||||
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
|
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
|
||||||
// This implementation is based on getting the netmap from `tailscale debug watch-ipn`
|
// This implementation is based on getting the netmap from `tailscale debug watch-ipn`
|
||||||
// as there seem to be some weirdness omitting endpoint and DERP info if we use
|
// as there seem to be some weirdness omitting endpoint and DERP info if we use
|
||||||
// Patch updates.
|
// Patch updates.
|
||||||
// This implementation works on all supported versions.
|
// This implementation works on all supported versions.
|
||||||
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
|
// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
|
||||||
// watch-ipn will only give an update if something is happening,
|
// // watch-ipn will only give an update if something is happening,
|
||||||
// since we send keep alives, the worst case for this should be
|
// // since we send keep alives, the worst case for this should be
|
||||||
// 1 minute, but set a slightly more conservative time.
|
// // 1 minute, but set a slightly more conservative time.
|
||||||
ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute)
|
// ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
|
|
||||||
notify, err := t.watchIPN(ctx)
|
// notify, err := t.watchIPN(ctx)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
// }
|
||||||
|
|
||||||
if notify.NetMap == nil {
|
// if notify.NetMap == nil {
|
||||||
return nil, fmt.Errorf("no netmap present in ipn.Notify")
|
// return nil, fmt.Errorf("no netmap present in ipn.Notify")
|
||||||
}
|
// }
|
||||||
|
|
||||||
return notify.NetMap, nil
|
// return notify.NetMap, nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until
|
// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until
|
||||||
// it gets one that has a netmap.NetworkMap.
|
// it gets one that has a netmap.NetworkMap.
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
@ -154,11 +155,11 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) {
|
||||||
func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
|
if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
|
||||||
// t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
|
t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
|
||||||
|
|
||||||
// return
|
return
|
||||||
// }
|
}
|
||||||
|
|
||||||
t.Logf("Checking netmap of %q", client.Hostname())
|
t.Logf("Checking netmap of %q", client.Hostname())
|
||||||
|
|
||||||
|
@ -175,7 +176,11 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||||
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||||
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||||
|
|
||||||
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
|
if netmap.SelfNode.Online() != nil {
|
||||||
|
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
|
||||||
|
} else {
|
||||||
|
t.Errorf("Online should not be nil for %s", client.Hostname())
|
||||||
|
}
|
||||||
|
|
||||||
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||||
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||||
|
@ -213,7 +218,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||||
// This test is not suitable for ACL/partial connection tests.
|
// This test is not suitable for ACL/partial connection tests.
|
||||||
func assertValidStatus(t *testing.T, client TailscaleClient) {
|
func assertValidStatus(t *testing.T, client TailscaleClient) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
status, err := client.Status()
|
status, err := client.Status(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue