Replace the timestamp based state system
This commit replaces the timestamp based state system with a new one that has update channels directly to the connected nodes. It will send an update to all listening clients via the polling mechanism. It introduces a new package notifier, which has a concurrency safe manager for all our channels to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
056d3a81c5
commit
66ff1fcd40
13 changed files with 216 additions and 731 deletions
108
hscontrol/app.go
108
hscontrol/app.go
|
@ -10,7 +10,6 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -26,13 +25,13 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/derp"
|
||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/puzpuzpuz/xsync/v2"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
@ -84,7 +83,7 @@ type Headscale struct {
|
|||
|
||||
ACLPolicy *policy.ACLPolicy
|
||||
|
||||
lastStateChange *xsync.MapOf[string, time.Time]
|
||||
nodeNotifier *notifier.Notifier
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
|
@ -93,9 +92,6 @@ type Headscale struct {
|
|||
|
||||
shutdownChan chan struct{}
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
|
||||
stateUpdateChan chan struct{}
|
||||
cancelStateUpdateChan chan struct{}
|
||||
}
|
||||
|
||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
|
@ -158,19 +154,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
noisePrivateKey: noisePrivateKey,
|
||||
registrationCache: registrationCache,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
lastStateChange: xsync.NewMapOf[time.Time](),
|
||||
|
||||
stateUpdateChan: make(chan struct{}),
|
||||
cancelStateUpdateChan: make(chan struct{}),
|
||||
nodeNotifier: notifier.NewNotifier(),
|
||||
}
|
||||
|
||||
go app.watchStateChannel()
|
||||
|
||||
database, err := db.NewHeadscaleDatabase(
|
||||
cfg.DBtype,
|
||||
dbString,
|
||||
app.dbDebug,
|
||||
app.stateUpdateChan,
|
||||
app.nodeNotifier,
|
||||
cfg.IPPrefixes,
|
||||
cfg.BaseDomain)
|
||||
if err != nil {
|
||||
|
@ -203,7 +194,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
|
||||
if cfg.DERP.ServerEnabled {
|
||||
// TODO(kradalby): replace this key with a dedicated DERP key.
|
||||
embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP)
|
||||
embeddedDERPServer, err := derpServer.NewDERPServer(
|
||||
cfg.ServerURL,
|
||||
key.NodePrivate(*privateKey),
|
||||
&cfg.DERP,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -230,10 +225,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
|||
|
||||
// expireExpiredMachines expires machines that have an explicit expiry set
|
||||
// after that expiry time has passed.
|
||||
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
||||
interval := time.Duration(intervalMs) * time.Millisecond
|
||||
ticker := time.NewTicker(interval)
|
||||
|
||||
lastCheck := time.Unix(0, 0)
|
||||
|
||||
for range ticker.C {
|
||||
h.db.ExpireExpiredMachines(h.getLastStateChange())
|
||||
lastCheck = h.db.ExpireExpiredMachines(lastCheck)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,7 +257,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
|||
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
h.nodeNotifier.NotifyAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -722,7 +721,7 @@ func (h *Headscale) Serve() error {
|
|||
Str("path", aclPath).
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
h.setLastStateChangeToNow()
|
||||
h.nodeNotifier.NotifyAll()
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -760,10 +759,6 @@ func (h *Headscale) Serve() error {
|
|||
// Stop listening (and unlink the socket if unix type):
|
||||
socketListener.Close()
|
||||
|
||||
<-h.cancelStateUpdateChan
|
||||
close(h.stateUpdateChan)
|
||||
close(h.cancelStateUpdateChan)
|
||||
|
||||
// Close db connections
|
||||
err = h.db.Close()
|
||||
if err != nil {
|
||||
|
@ -859,73 +854,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): baby steps, make this more robust.
|
||||
func (h *Headscale) watchStateChannel() {
|
||||
for {
|
||||
select {
|
||||
case <-h.stateUpdateChan:
|
||||
h.setLastStateChangeToNow()
|
||||
|
||||
case <-h.cancelStateUpdateChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) setLastStateChangeToNow() {
|
||||
var err error
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("failed to fetch all users, failing to update last changed state.")
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix()))
|
||||
if h.lastStateChange == nil {
|
||||
h.lastStateChange = xsync.NewMapOf[time.Time]()
|
||||
}
|
||||
h.lastStateChange.Store(user.Name, now)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
|
||||
times := []time.Time{}
|
||||
|
||||
// getLastStateChange takes a list of users as a "filter", if no users
|
||||
// are past, then use the entier list of users and look for the last update
|
||||
if len(users) > 0 {
|
||||
for _, user := range users {
|
||||
if lastChange, ok := h.lastStateChange.Load(user.Name); ok {
|
||||
times = append(times, lastChange)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
h.lastStateChange.Range(func(key string, value time.Time) bool {
|
||||
times = append(times, value)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(times, func(i, j int) bool {
|
||||
return times[i].After(times[j])
|
||||
})
|
||||
|
||||
log.Trace().Msgf("Latest times %#v", times)
|
||||
|
||||
if len(times) == 0 {
|
||||
return time.Now().UTC()
|
||||
} else {
|
||||
return times[0]
|
||||
}
|
||||
}
|
||||
|
||||
func notFoundHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
|
|
|
@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
|
@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
|||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
|
@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
|||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
|
|
@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
|||
keys, err := db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||
|
@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
|||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||
|
@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
|||
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validWithErr, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||
|
@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
|
|||
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(notValid, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -37,7 +38,7 @@ type KV struct {
|
|||
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
notifier *notifier.Notifier
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
|
@ -50,7 +51,7 @@ type HSDatabase struct {
|
|||
func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
debug bool,
|
||||
notifyStateChan chan<- struct{},
|
||||
notifier *notifier.Notifier,
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
|
@ -61,7 +62,7 @@ func NewHeadscaleDatabase(
|
|||
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
notifier: notifier,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
|
@ -297,10 +298,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
|||
)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) notifyStateChange() {
|
||||
hsdb.notifyStateChan <- struct{}{}
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV.
|
||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
|
|
|
@ -218,7 +218,7 @@ func (hsdb *HSDatabase) SetTags(
|
|||
}
|
||||
machine.ForcedTags = newTags
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||
|
@ -232,7 +232,7 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
|||
now := time.Now()
|
||||
machine.Expiry = &now
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||
|
@ -259,7 +259,7 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
|
|||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
|
@ -275,7 +275,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
|
|||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &expiry
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
|
@ -323,32 +323,6 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
|
||||
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
// Get the last update from all headscale users to compare with our nodes
|
||||
// last update.
|
||||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
||||
// This would mostly be for a bit of performance, and can be calculated based on
|
||||
// ACLs.
|
||||
lastUpdate := machine.CreatedAt
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastChange).
|
||||
Time("last_state_change", lastUpdate).
|
||||
Msgf("Checking if %s is missing updates", machine.Hostname)
|
||||
|
||||
return lastUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||
cache *cache.Cache,
|
||||
nodeKeyStr string,
|
||||
|
@ -626,7 +600,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
|||
}
|
||||
}
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -723,17 +697,22 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
|||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
||||
// use the time of the start of the function to ensure we
|
||||
// dont miss some machines by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
users, err := hsdb.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
return time.Unix(0, 0)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
|
@ -744,13 +723,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
return time.Unix(0, 0)
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for index, machine := range machines {
|
||||
if machine.IsExpired() &&
|
||||
machine.Expiry.After(lastChange) {
|
||||
machine.Expiry.After(lastCheck) {
|
||||
expiredFound = true
|
||||
|
||||
err := hsdb.ExpireMachine(&machines[index])
|
||||
|
@ -770,7 +749,9 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
}
|
||||
|
||||
return started
|
||||
}
|
||||
|
|
|
@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||
|
@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||
|
@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||
|
@ -131,8 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
|||
|
||||
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
|
@ -155,8 +147,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine(user.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||
|
@ -179,8 +169,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
|||
|
||||
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
|
@ -217,8 +205,6 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
|
@ -312,8 +298,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||
|
@ -349,8 +333,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||
|
@ -372,8 +354,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
|||
for i := range deserialized {
|
||||
c.Assert(deserialized[i], check.Equals, input[i])
|
||||
}
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
|
@ -418,8 +398,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
|
@ -463,8 +441,6 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||
check.DeepEquals,
|
||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
|
@ -655,6 +631,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
|
|
@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
|||
// The machine record should have been deleted
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
|
|
|
@ -374,7 +374,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||
}
|
||||
|
||||
if routesChanged {
|
||||
hsdb.notifyStateChange()
|
||||
hsdb.notifier.NotifyAll()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
|
@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||
|
@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||
|
@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(6))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||
|
@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@ package db
|
|||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
|
@ -20,14 +20,9 @@ type Suite struct{}
|
|||
var (
|
||||
tmpDir string
|
||||
db *HSDatabase
|
||||
|
||||
// channelUpdates counts the number of times
|
||||
// either of the channels was notified.
|
||||
channelUpdates int32
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
atomic.StoreInt32(&channelUpdates, 0)
|
||||
s.ResetDB(c)
|
||||
}
|
||||
|
||||
|
@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) {
|
|||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
func notificationSink(c <-chan struct{}) {
|
||||
for {
|
||||
<-c
|
||||
atomic.AddInt32(&channelUpdates, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) ResetDB(c *check.C) {
|
||||
if len(tmpDir) != 0 {
|
||||
os.RemoveAll(tmpDir)
|
||||
|
@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
c.Fatal(err)
|
||||
}
|
||||
|
||||
sink := make(chan struct{})
|
||||
|
||||
go notificationSink(sink)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
sink,
|
||||
notifier.NewNotifier(),
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
|
|
|
@ -10,12 +10,6 @@ const prometheusNamespace = "headscale"
|
|||
var (
|
||||
// This is a high cardinality metric (user x machines), we might want to make this
|
||||
// configurable/opt-in in the future.
|
||||
lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "last_update_seconds",
|
||||
Help: "Time stamp in unix time when a machine or headscale was updated",
|
||||
}, []string{"user", "machine"})
|
||||
|
||||
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "machine_registrations_total",
|
||||
|
@ -33,9 +27,4 @@ var (
|
|||
Help: "The number of calls/messages issued on a specific nodes update channel",
|
||||
}, []string{"user", "machine", "status"})
|
||||
// TODO(kradalby): This is very debugging, we might want to remove it.
|
||||
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "update_request_received_on_channel_total",
|
||||
Help: "The number of update requests received on an update channel",
|
||||
}, []string{"user", "machine"})
|
||||
)
|
||||
|
|
55
hscontrol/notifier/notifier.go
Normal file
55
hscontrol/notifier/notifier.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package notifier
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
l sync.RWMutex
|
||||
nodes map[string]chan<- struct{}
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
||||
if n.nodes == nil {
|
||||
n.nodes = make(map[string]chan<- struct{})
|
||||
}
|
||||
|
||||
n.nodes[machineKey] = c
|
||||
}
|
||||
|
||||
func (n *Notifier) RemoveNode(machineKey string) {
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
|
||||
if n.nodes == nil {
|
||||
return
|
||||
}
|
||||
|
||||
delete(n.nodes, machineKey)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll() {
|
||||
n.NotifyWithIgnore()
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
||||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
||||
for key, c := range n.nodes {
|
||||
if util.IsStringInSlice(ignore, key) {
|
||||
continue
|
||||
}
|
||||
|
||||
c <- struct{}{}
|
||||
}
|
||||
}
|
|
@ -21,6 +21,38 @@ type contextKey string
|
|||
|
||||
const machineNameContextKey = contextKey("machineName")
|
||||
|
||||
type UpdateNode func()
|
||||
|
||||
func logPollFunc(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) (func(string), func(error, string)) {
|
||||
return func(msg string) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("noise", isNoise).
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg(msg)
|
||||
},
|
||||
func(err error, msg string) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("noise", isNoise).
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePoll is the common code for the legacy and Noise protocols to
|
||||
// managed the poll loop.
|
||||
func (h *Headscale) handlePoll(
|
||||
|
@ -30,6 +62,10 @@ func (h *Headscale) handlePoll(
|
|||
mapRequest tailcfg.MapRequest,
|
||||
isNoise bool,
|
||||
) {
|
||||
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
|
||||
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(
|
||||
h.db,
|
||||
h.privateKey2019,
|
||||
|
@ -48,11 +84,7 @@ func (h *Headscale) handlePoll(
|
|||
|
||||
err := h.db.ProcessMachineRoutes(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Error processing machine routes")
|
||||
logErr(err, "Error processing machine routes")
|
||||
}
|
||||
|
||||
// update ACLRules with peer informations (to update server tags if necessary)
|
||||
|
@ -60,12 +92,7 @@ func (h *Headscale) handlePoll(
|
|||
// update routes with peer information
|
||||
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Error running auto approved routes")
|
||||
logErr(err, "Error running auto approved routes")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,13 +110,7 @@ func (h *Headscale) handlePoll(
|
|||
}
|
||||
|
||||
if err := h.db.MachineSave(machine); err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to persist/update machine in the database")
|
||||
logErr(err, "Failed to persist/update machine in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
|
@ -97,13 +118,7 @@ func (h *Headscale) handlePoll(
|
|||
|
||||
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to get Map response")
|
||||
logErr(err, "Failed to create MapResponse")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
|
@ -114,30 +129,16 @@ func (h *Headscale) handlePoll(
|
|||
// empty endpoints to peers)
|
||||
|
||||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Msg("Client map request processed")
|
||||
logInfo("Client map request processed")
|
||||
|
||||
if mapRequest.ReadOnly {
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Client is starting up. Probably interested in a DERP map")
|
||||
logInfo("Client is starting up. Probably interested in a DERP map")
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err := writer.Write(mapResp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
logErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
if f, ok := writer.(http.Flusher); ok {
|
||||
|
@ -147,48 +148,22 @@ func (h *Headscale) handlePoll(
|
|||
return
|
||||
}
|
||||
|
||||
// There has been an update to _any_ of the nodes that the other nodes would
|
||||
// need to know about
|
||||
h.setLastStateChangeToNow()
|
||||
|
||||
// The request is not ReadOnly, so we need to set up channels for updating
|
||||
// peers via longpoll
|
||||
|
||||
// Only create update channel if it has not been created
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Loading or creating update channel")
|
||||
|
||||
const chanSize = 8
|
||||
updateChan := make(chan struct{}, chanSize)
|
||||
|
||||
pollDataChan := make(chan []byte, chanSize)
|
||||
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
|
||||
|
||||
keepAliveChan := make(chan []byte)
|
||||
|
||||
if mapRequest.OmitPeers && !mapRequest.Stream {
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
||||
logInfo("Client sent endpoint update and is ok with a response without peer list")
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err := writer.Write(mapResp)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
logErr(err, "Failed to write response")
|
||||
}
|
||||
// It sounds like we should update the nodes when we have received a endpoint update
|
||||
// even tho the comments in the tailscale code dont explicitly say so.
|
||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update").
|
||||
Inc()
|
||||
updateChan <- struct{}{}
|
||||
|
||||
// Tell all the other nodes about the new endpoint, but dont update ourselves.
|
||||
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
|
||||
|
||||
return
|
||||
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
||||
|
@ -202,43 +177,32 @@ func (h *Headscale) handlePoll(
|
|||
return
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Client is ready to access the tailnet")
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Sending initial map")
|
||||
pollDataChan <- mapResp
|
||||
logInfo("Sending initial map")
|
||||
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Notifying peers")
|
||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update").
|
||||
Inc()
|
||||
updateChan <- struct{}{}
|
||||
// Send the client an update to make sure we send an initial mapresponse
|
||||
_, err = writer.Write(mapResp)
|
||||
if err != nil {
|
||||
logErr(err, "Could not write the map response")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
h.pollNetMapStream(
|
||||
writer,
|
||||
ctx,
|
||||
machine,
|
||||
mapp,
|
||||
mapRequest,
|
||||
pollDataChan,
|
||||
keepAliveChan,
|
||||
updateChan,
|
||||
isNoise,
|
||||
)
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Finished stream, closing PollNetMap session")
|
||||
logInfo("Finished stream, closing PollNetMap session")
|
||||
}
|
||||
|
||||
// pollNetMapStream stream logic for /machine/map,
|
||||
|
@ -247,23 +211,16 @@ func (h *Headscale) pollNetMapStream(
|
|||
writer http.ResponseWriter,
|
||||
ctxReq context.Context,
|
||||
machine *types.Machine,
|
||||
mapp *mapper.Mapper,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
pollDataChan chan []byte,
|
||||
keepAliveChan chan []byte,
|
||||
updateChan chan struct{},
|
||||
isNoise bool,
|
||||
) {
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(h.db,
|
||||
h.privateKey2019,
|
||||
isNoise,
|
||||
h.DERPMap,
|
||||
h.cfg.BaseDomain,
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.LogTail.Enabled,
|
||||
h.cfg.RandomizeClientPort,
|
||||
)
|
||||
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
|
||||
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
|
||||
const chanSize = 8
|
||||
updateChan := make(chan struct{}, chanSize)
|
||||
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
@ -273,447 +230,93 @@ func (h *Headscale) pollNetMapStream(
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go h.scheduledPollWorker(
|
||||
ctx,
|
||||
updateChan,
|
||||
keepAliveChan,
|
||||
mapRequest,
|
||||
machine,
|
||||
isNoise,
|
||||
)
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "pollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Waiting for data to stream...")
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "pollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
||||
// Register the node's update channel
|
||||
h.nodeNotifier.AddNode(machine.MachineKey, updateChan)
|
||||
defer h.nodeNotifier.RemoveNode(machine.MachineKey)
|
||||
defer closeChanWithLog(updateChan, machine.Hostname, "updateChan")
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-pollDataChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending data received via pollData channel")
|
||||
_, err := writer.Write(data)
|
||||
case <-keepAliveTicker.C:
|
||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Err(err).
|
||||
Msg("Cannot write data")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Msg("Cannot cast writer to http.Flusher")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Data from pollData channel written successfully")
|
||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.db.UpdateMachineFromDatabase(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
|
||||
// client has been removed from database
|
||||
// since the stream opened, terminate connection.
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
machine.LastSeen = &now
|
||||
|
||||
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||
Set(float64(now.Unix()))
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
|
||||
err = h.db.TouchMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Err(err).
|
||||
Msg("Cannot update machine LastSuccessfulUpdate")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Machine entry in database updated successfully after sending data")
|
||||
|
||||
case data := <-keepAliveChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending keep alive message")
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Cannot write keep alive message")
|
||||
|
||||
return
|
||||
}
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Msg("Cannot cast writer to http.Flusher")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Keep alive sent successfully")
|
||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.db.UpdateMachineFromDatabase(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
|
||||
// client has been removed from database
|
||||
// since the stream opened, terminate connection.
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
machine.LastSeen = &now
|
||||
err = h.db.TouchMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Cannot update machine LastSeen")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Machine updated successfully after sending keep alive")
|
||||
|
||||
case <-updateChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Msg("Received a request for update")
|
||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||
Inc()
|
||||
|
||||
if h.db.IsOutdated(machine, h.getLastStateChange()) {
|
||||
var lastUpdate time.Time
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
||||
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
|
||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Could not get the map update")
|
||||
logErr(err, "Error generating the keep alive msg")
|
||||
|
||||
return
|
||||
}
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Could not write the map response")
|
||||
logErr(err, "Cannot write keep alive message")
|
||||
|
||||
return
|
||||
}
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.TouchMachine(machine)
|
||||
if err != nil {
|
||||
logErr(err, "Cannot update machine LastSeen")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
case <-updateChan:
|
||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||
if err != nil {
|
||||
logErr(err, "Could not get the map update")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
logErr(err, "Could not write the map response")
|
||||
|
||||
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Msg("Cannot cast writer to http.Flusher")
|
||||
} else {
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Msg("Updated Map has been sent")
|
||||
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success").
|
||||
Inc()
|
||||
|
||||
// Keep track of the last successful update,
|
||||
// we sometimes end in a state were the update
|
||||
// is not picked up by a client and we use this
|
||||
// to determine if we should "force" an update.
|
||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.db.UpdateMachineFromDatabase(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
|
||||
// client has been removed from database
|
||||
// since the stream opened, terminate connection.
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
|
||||
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||
Set(float64(now.Unix()))
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
|
||||
err = h.db.TouchMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Cannot update machine LastSuccessfulUpdate")
|
||||
logErr(err, "Cannot update machine LastSuccessfulUpdate")
|
||||
|
||||
return
|
||||
}
|
||||
} else {
|
||||
var lastUpdate time.Time
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
||||
Msgf("%s is up to date", machine.Hostname)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
log.Info().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("The client has closed the connection")
|
||||
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err := h.db.UpdateMachineFromDatabase(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "Done").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
logInfo("The client has closed the connection")
|
||||
|
||||
// client has been removed from database
|
||||
// since the stream opened, terminate connection.
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
machine.LastSeen = &now
|
||||
err = h.db.TouchMachine(machine)
|
||||
err := h.db.TouchMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("channel", "Done").
|
||||
Err(err).
|
||||
Msg("Cannot update machine LastSeen")
|
||||
logErr(err, "Cannot update machine LastSeen")
|
||||
}
|
||||
|
||||
// The connection has been closed, so we can stop polling.
|
||||
return
|
||||
|
||||
case <-h.shutdownChan:
|
||||
log.Info().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("The long-poll handler is shutting down")
|
||||
logInfo("The long-poll handler is shutting down")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) scheduledPollWorker(
|
||||
ctx context.Context,
|
||||
updateChan chan struct{},
|
||||
keepAliveChan chan []byte,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) {
|
||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||
// per client or something similar
|
||||
mapp := mapper.NewMapper(h.db,
|
||||
h.privateKey2019,
|
||||
isNoise,
|
||||
h.DERPMap,
|
||||
h.cfg.BaseDomain,
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.LogTail.Enabled,
|
||||
h.cfg.RandomizeClientPort,
|
||||
)
|
||||
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
|
||||
|
||||
defer closeChanWithLog(
|
||||
updateChan,
|
||||
fmt.Sprint(ctx.Value(machineNameContextKey)),
|
||||
"updateChan",
|
||||
)
|
||||
defer closeChanWithLog(
|
||||
keepAliveChan,
|
||||
fmt.Sprint(ctx.Value(machineNameContextKey)),
|
||||
"keepAliveChan",
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-keepAliveTicker.C:
|
||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "keepAlive").
|
||||
Bool("noise", isNoise).
|
||||
Err(err).
|
||||
Msg("Error generating the keep alive msg")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("func", "keepAlive").
|
||||
Str("machine", machine.Hostname).
|
||||
Bool("noise", isNoise).
|
||||
Msg("Sending keepalive")
|
||||
select {
|
||||
case keepAliveChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
case <-updateCheckerTicker.C:
|
||||
log.Debug().
|
||||
Str("func", "scheduledPollWorker").
|
||||
Str("machine", machine.Hostname).
|
||||
Bool("noise", isNoise).
|
||||
Msg("Sending update request")
|
||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update").
|
||||
Inc()
|
||||
select {
|
||||
case updateChan <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
|
|
Loading…
Reference in a new issue