diff --git a/hscontrol/app.go b/hscontrol/app.go index 95f731d..ad0a66f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { h.DERPMap.Regions[region.RegionID] = ®ion } - h.nodeNotifier.NotifyAll() + h.nodeNotifier.NotifyAll(types.StateUpdate{ + Type: types.StateDERPUpdated, + DERPMap: *h.DERPMap, + }) } } } @@ -721,7 +724,9 @@ func (h *Headscale) Serve() error { Str("path", aclPath). Msg("ACL policy successfully reloaded, notifying nodes of change") - h.nodeNotifier.NotifyAll() + h.nodeNotifier.NotifyAll(types.StateUpdate{ + Type: types.StateFullUpdate, + }) } default: diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 936019d..47dfaa1 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -13,6 +13,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "gorm.io/gorm" + "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags( } machine.ForcedTags = newTags - hsdb.notifier.NotifyWithIgnore(machine.MachineKey) + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) @@ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { now := time.Now() machine.Expiry = &now - hsdb.notifier.NotifyWithIgnore(machine.MachineKey) + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to expire machine in the database: %w", err) @@ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er } machine.GivenName = newName - hsdb.notifier.NotifyWithIgnore(machine.MachineKey) + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to rename machine in the database: %w", err) @@ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) machine.LastSuccessfulUpdate = &now machine.Expiry = &expiry - hsdb.notifier.NotifyWithIgnore(machine.MachineKey) + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf( @@ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) return false } +func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { + ret := make(map[tailcfg.NodeID]bool) + + for _, peer := range peers { + ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline() + } + + return ret +} + +func (hsdb *HSDatabase) ListOnlineMachines( + machine *types.Machine, +) (map[tailcfg.NodeID]bool, error) { + peers, err := hsdb.ListPeers(machine) + if err != nil { + return nil, err + } + + return OnlineMachineMap(peers), nil +} + // enableRoutes enables new routes based on a list of new routes. func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) @@ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string } } - hsdb.notifier.NotifyWithIgnore(machine.MachineKey) + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) return nil } @@ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati return } - expiredFound := false + expired := make([]tailcfg.NodeID, 0) for idx, machine := range machines { if machine.IsEphemeral() && machine.LastSeen != nil && time.Now(). After(machine.LastSeen.Add(inactivityThreshhold)) { - expiredFound = true + expired = append(expired, tailcfg.NodeID(machine.ID)) + log.Info(). Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") @@ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } } - if expiredFound { - hsdb.notifier.NotifyAll() + if len(expired) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }) } } } @@ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { return time.Unix(0, 0) } - expiredFound := false + expired := make([]tailcfg.NodeID, 0) for index, machine := range machines { if machine.IsExpired() && machine.Expiry.After(lastCheck) { - expiredFound = true + expired = append(expired, tailcfg.NodeID(machine.ID)) err := hsdb.ExpireMachine(&machines[index]) if err != nil { @@ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { } } - if expiredFound { - hsdb.notifier.NotifyAll() + if len(expired) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }) } } diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index b3604a1..90ec3b1 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { log.Error().Err(err).Msg("error getting routes") } - routesChanged := false + changedMachines := make([]uint64, 0) for pos, route := range routes { if route.IsExitRoute() { continue @@ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { return err } - routesChanged = true + changedMachines = append(changedMachines, route.MachineID) continue } @@ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { return err } - routesChanged = true + changedMachines = append(changedMachines, route.MachineID) } } - if routesChanged { - hsdb.notifier.NotifyAll() + if len(changedMachines) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: changedMachines, + }) } return nil diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5d5509b..819a7fb 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/url" + "sort" "strings" "sync" "time" @@ -129,45 +130,35 @@ func fullMapResponse( return nil, err } + // Peers is always returned sorted by Node.ID. + sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + now := time.Now() resp := tailcfg.MapResponse{ - KeepAlive: false, - Node: tailnode, - - // TODO: Only send if updated - DERPMap: derpMap, - - // TODO: Only send if updated + Node: tailnode, Peers: tailPeers, - // TODO(kradalby): Implement: - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374 - // PeersChanged - // PeersRemoved - // PeersChangedPatch - // PeerSeenChange - // OnlineChange + DERPMap: derpMap, - // TODO: Only send if updated DNSConfig: dnsConfig, + Domain: baseDomain, - // TODO: Only send if updated - Domain: baseDomain, - - // Do not instruct clients to collect services, we do not + // Do not instruct clients to collect services we do not // support or do anything with them CollectServices: "false", - // TODO: Only send if updated PacketFilter: policy.ReduceFilterRules(machine, rules), UserProfiles: profiles, - // TODO: Only send if updated SSHPolicy: sshPolicy, - ControlTime: &now, + ControlTime: &now, + KeepAlive: false, + OnlineChange: db.OnlineMachineMap(peers), Debug: &tailcfg.Debug{ DisableLogTail: !logtail, @@ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { } } -// CreateMapResponse returns a MapResponse for the given machine. -func (m Mapper) CreateMapResponse( +// FullMapResponse returns a MapResponse for the given machine. +func (m Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, machine *types.Machine, pol *policy.ACLPolicy, @@ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse( } if m.isNoise { - return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) + return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) } - var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse client key") - - return nil, err - } - - return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) + return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) } -func (m Mapper) CreateKeepAliveResponse( +func (m Mapper) KeepAliveResponse( mapRequest tailcfg.MapRequest, machine *types.Machine, ) ([]byte, error) { - keepAliveResponse := tailcfg.MapResponse{ - KeepAlive: true, + resp := m.baseMapResponse(machine) + resp.KeepAlive = true + + return m.marshalMapResponse(&resp, machine, mapRequest.Compress) +} + +func (m Mapper) DERPMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + derpMap tailcfg.DERPMap, +) ([]byte, error) { + resp := m.baseMapResponse(machine) + resp.DERPMap = &derpMap + + return m.marshalMapResponse(&resp, machine, mapRequest.Compress) +} + +func (m Mapper) PeerChangedResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + machineKeys []uint64, + pol *policy.ACLPolicy, +) ([]byte, error) { + var err error + changed := make(types.Machines, len(machineKeys)) + lastSeen := make(map[tailcfg.NodeID]bool) + for idx, machineKey := range machineKeys { + peer, err := m.db.GetMachineByID(machineKey) + if err != nil { + return nil, err + } + + changed[idx] = *peer + + // We have just seen the node, let the peers update their list. + lastSeen[tailcfg.NodeID(peer.ID)] = true } - if m.isNoise { - return m.marshalMapResponse( - keepAliveResponse, - key.MachinePublic{}, - mapRequest.Compress, - ) + rules, _, err := policy.GenerateFilterAndSSHRules( + pol, + machine, + changed, + ) + if err != nil { + return nil, err } + // Filter out peers that have expired. + changed = lo.Filter(changed, func(item types.Machine, index int) bool { + return !item.IsExpired() + }) + + // If there are filter rules present, see if there are any machines that cannot + // access eachother at all and remove them from the changed. + if len(rules) > 0 { + changed = policy.FilterMachinesByACL(machine, changed, rules) + } + + tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain) + if err != nil { + return nil, err + } + + // Peers is always returned sorted by Node.ID. + sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + resp := m.baseMapResponse(machine) + resp.PeersChanged = tailPeers + resp.PeerSeenChange = lastSeen + + return m.marshalMapResponse(&resp, machine, mapRequest.Compress) +} + +func (m Mapper) PeerRemovedResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + removed []tailcfg.NodeID, +) ([]byte, error) { + resp := m.baseMapResponse(machine) + resp.PeersRemoved = removed + + return m.marshalMapResponse(&resp, machine, mapRequest.Compress) +} + +func (m Mapper) marshalMapResponse( + resp *tailcfg.MapResponse, + machine *types.Machine, + compression string, +) ([]byte, error) { var machineKey key.MachinePublic err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) if err != nil { @@ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse( return nil, err } - return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) -} - -// MarshalResponse takes an Tailscale Response, marhsal it to JSON. -// If isNoise is set, then the JSON body will be returned -// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box. -func MarshalResponse( - resp interface{}, - isNoise bool, - privateKey2019 *key.MachinePrivate, - machineKey key.MachinePublic, -) ([]byte, error) { - jsonBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot marshal response") - - return nil, err - } - - if !isNoise && privateKey2019 != nil { - return privateKey2019.SealTo(machineKey, jsonBody), nil - } - - return jsonBody, nil -} - -func (m Mapper) marshalMapResponse( - resp interface{}, - machineKey key.MachinePublic, - compression string, -) ([]byte, error) { jsonBody, err := json.Marshal(resp) if err != nil { log.Error(). @@ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse( return data, nil } +// MarshalResponse takes an Tailscale Response, marhsal it to JSON. +// If isNoise is set, then the JSON body will be returned +// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box. +func MarshalResponse( + resp interface{}, + isNoise bool, + privateKey2019 *key.MachinePrivate, + machineKey key.MachinePublic, +) ([]byte, error) { + jsonBody, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot marshal response") + + return nil, err + } + + if !isNoise && privateKey2019 != nil { + return privateKey2019.SealTo(machineKey, jsonBody), nil + } + + return jsonBody, nil +} + func zstdEncode(in []byte) []byte { encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) if !ok { @@ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{ return encoder }, } + +func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse { + now := time.Now() + + resp := tailcfg.MapResponse{ + KeepAlive: false, + ControlTime: &now, + } + + online, err := m.db.ListOnlineMachines(machine) + if err == nil { + resp.OnlineChange = online + } + + return resp +} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 0ca9633..fa3d243 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", + OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false}, PacketFilter: []tailcfg.FilterRule{}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, @@ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", + OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false}, PacketFilter: []tailcfg.FilterRule{ { SrcIPs: []string{"100.64.0.2/32"}, diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index f4e25b2..53fcd6a 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -3,24 +3,25 @@ package notifier import ( "sync" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) type Notifier struct { l sync.RWMutex - nodes map[string]chan<- struct{} + nodes map[string]chan<- types.StateUpdate } func NewNotifier() *Notifier { return &Notifier{} } -func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) { +func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) { n.l.Lock() defer n.l.Unlock() if n.nodes == nil { - n.nodes = make(map[string]chan<- struct{}) + n.nodes = make(map[string]chan<- types.StateUpdate) } n.nodes[machineKey] = c @@ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) { delete(n.nodes, machineKey) } -func (n *Notifier) NotifyAll() { - n.NotifyWithIgnore() +func (n *Notifier) NotifyAll(update types.StateUpdate) { + n.NotifyWithIgnore(update) } -func (n *Notifier) NotifyWithIgnore(ignore ...string) { +func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { n.l.RLock() defer n.l.RUnlock() @@ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) { continue } - c <- struct{}{} + c <- update } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 3b6cde2..eaf759d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -116,7 +116,7 @@ func (h *Headscale) handlePoll( return } - mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) + mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) if err != nil { logErr(err, "Failed to create MapResponse") http.Error(writer, "", http.StatusInternalServerError) @@ -163,7 +163,12 @@ func (h *Headscale) handlePoll( Inc() // Tell all the other nodes about the new endpoint, but dont update ourselves. - h.nodeNotifier.NotifyWithIgnore(machine.MachineKey) + h.nodeNotifier.NotifyWithIgnore( + types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, + machine.MachineKey) return } else if mapRequest.OmitPeers && mapRequest.Stream { @@ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream( keepAliveTicker := time.NewTicker(keepAliveInterval) const chanSize = 8 - updateChan := make(chan struct{}, chanSize) + updateChan := make(chan types.StateUpdate, chanSize) h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream( for { select { case <-keepAliveTicker.C: - data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) + data, err := mapp.KeepAliveResponse(mapRequest, machine) if err != nil { logErr(err, "Error generating the keep alive msg") @@ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream( return } - case <-updateChan: - data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) + case update := <-updateChan: + var data []byte + var err error + + switch update.Type { + case types.StateFullUpdate: + data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) + case types.StatePeerChanged: + data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy) + case types.StatePeerRemoved: + data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed) + case types.StateDERPUpdated: + data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap) + } + if err != nil { - logErr(err, "Could not get the map update") + logErr(err, "Could not get the create map update") return } @@ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream( } } -func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { +func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) { log.Trace(). Str("handler", "PollNetMap"). Str("machine", machine). diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 96ad1b7..3a00104 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) { return string(bytes), err } + +type StateUpdateType int + +const ( + StateFullUpdate StateUpdateType = iota + StatePeerChanged + StatePeerRemoved + StateDERPUpdated +) + +// StateUpdate is an internal message containing information about +// a state change that has happened to the network. +type StateUpdate struct { + // The type of update + Type StateUpdateType + + // Changed must be set when Type is StatePeerChanged and + // contain the Machine IDs of machines that has changed. + Changed []uint64 + + // Removed must be set when Type is StatePeerRemoved and + // contain a list of the nodes that has been removed from + // the network. + Removed []tailcfg.NodeID + + // DERPMap must be set when Type is StateDERPUpdated and + // contain the new DERP Map. + DERPMap tailcfg.DERPMap +}