Split up MapResponse
This commits extends the mapper with functions for creating "delta" MapResponses for different purposes (peer changed, peer removed, derp). This wires up the new state management with a new StateUpdate struct letting the poll worker know what kind of update to send to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
66ff1fcd40
commit
4b65cf48d0
8 changed files with 284 additions and 115 deletions
|
@ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
||||||
h.DERPMap.Regions[region.RegionID] = ®ion
|
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||||
}
|
}
|
||||||
|
|
||||||
h.nodeNotifier.NotifyAll()
|
h.nodeNotifier.NotifyAll(types.StateUpdate{
|
||||||
|
Type: types.StateDERPUpdated,
|
||||||
|
DERPMap: *h.DERPMap,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -721,7 +724,9 @@ func (h *Headscale) Serve() error {
|
||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||||
|
|
||||||
h.nodeNotifier.NotifyAll()
|
h.nodeNotifier.NotifyAll(types.StateUpdate{
|
||||||
|
Type: types.StateFullUpdate,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags(
|
||||||
}
|
}
|
||||||
machine.ForcedTags = newTags
|
machine.ForcedTags = newTags
|
||||||
|
|
||||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
}, machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||||
|
@ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine.Expiry = &now
|
machine.Expiry = &now
|
||||||
|
|
||||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
}, machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||||
|
@ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
|
||||||
}
|
}
|
||||||
machine.GivenName = newName
|
machine.GivenName = newName
|
||||||
|
|
||||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
}, machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||||
|
@ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
|
||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
machine.Expiry = &expiry
|
machine.Expiry = &expiry
|
||||||
|
|
||||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
}, machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
|
@ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool {
|
||||||
|
ret := make(map[tailcfg.NodeID]bool)
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) ListOnlineMachines(
|
||||||
|
machine *types.Machine,
|
||||||
|
) (map[tailcfg.NodeID]bool, error) {
|
||||||
|
peers, err := hsdb.ListPeers(machine)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return OnlineMachineMap(peers), nil
|
||||||
|
}
|
||||||
|
|
||||||
// enableRoutes enables new routes based on a list of new routes.
|
// enableRoutes enables new routes based on a list of new routes.
|
||||||
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
|
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
|
||||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||||
|
@ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
}, machine.MachineKey)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
expiredFound := false
|
expired := make([]tailcfg.NodeID, 0)
|
||||||
for idx, machine := range machines {
|
for idx, machine := range machines {
|
||||||
if machine.IsEphemeral() && machine.LastSeen != nil &&
|
if machine.IsEphemeral() && machine.LastSeen != nil &&
|
||||||
time.Now().
|
time.Now().
|
||||||
After(machine.LastSeen.Add(inactivityThreshhold)) {
|
After(machine.LastSeen.Add(inactivityThreshhold)) {
|
||||||
expiredFound = true
|
expired = append(expired, tailcfg.NodeID(machine.ID))
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("Ephemeral client removed from database")
|
Msg("Ephemeral client removed from database")
|
||||||
|
@ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if expiredFound {
|
if len(expired) > 0 {
|
||||||
hsdb.notifier.NotifyAll()
|
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||||
|
Type: types.StatePeerRemoved,
|
||||||
|
Removed: expired,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
||||||
return time.Unix(0, 0)
|
return time.Unix(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
expiredFound := false
|
expired := make([]tailcfg.NodeID, 0)
|
||||||
for index, machine := range machines {
|
for index, machine := range machines {
|
||||||
if machine.IsExpired() &&
|
if machine.IsExpired() &&
|
||||||
machine.Expiry.After(lastCheck) {
|
machine.Expiry.After(lastCheck) {
|
||||||
expiredFound = true
|
expired = append(expired, tailcfg.NodeID(machine.ID))
|
||||||
|
|
||||||
err := hsdb.ExpireMachine(&machines[index])
|
err := hsdb.ExpireMachine(&machines[index])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if expiredFound {
|
if len(expired) > 0 {
|
||||||
hsdb.notifier.NotifyAll()
|
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||||
|
Type: types.StatePeerRemoved,
|
||||||
|
Removed: expired,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||||
log.Error().Err(err).Msg("error getting routes")
|
log.Error().Err(err).Msg("error getting routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesChanged := false
|
changedMachines := make([]uint64, 0)
|
||||||
for pos, route := range routes {
|
for pos, route := range routes {
|
||||||
if route.IsExitRoute() {
|
if route.IsExitRoute() {
|
||||||
continue
|
continue
|
||||||
|
@ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routesChanged = true
|
changedMachines = append(changedMachines, route.MachineID)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routesChanged = true
|
changedMachines = append(changedMachines, route.MachineID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if routesChanged {
|
if len(changedMachines) > 0 {
|
||||||
hsdb.notifier.NotifyAll()
|
hsdb.notifier.NotifyAll(types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: changedMachines,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -129,45 +130,35 @@ func fullMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Peers is always returned sorted by Node.ID.
|
||||||
|
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||||
|
return tailPeers[x].ID < tailPeers[y].ID
|
||||||
|
})
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
resp := tailcfg.MapResponse{
|
||||||
KeepAlive: false,
|
|
||||||
Node: tailnode,
|
Node: tailnode,
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
DERPMap: derpMap,
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
Peers: tailPeers,
|
Peers: tailPeers,
|
||||||
|
|
||||||
// TODO(kradalby): Implement:
|
DERPMap: derpMap,
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
|
|
||||||
// PeersChanged
|
|
||||||
// PeersRemoved
|
|
||||||
// PeersChangedPatch
|
|
||||||
// PeerSeenChange
|
|
||||||
// OnlineChange
|
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
DNSConfig: dnsConfig,
|
DNSConfig: dnsConfig,
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
Domain: baseDomain,
|
Domain: baseDomain,
|
||||||
|
|
||||||
// Do not instruct clients to collect services, we do not
|
// Do not instruct clients to collect services we do not
|
||||||
// support or do anything with them
|
// support or do anything with them
|
||||||
CollectServices: "false",
|
CollectServices: "false",
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
PacketFilter: policy.ReduceFilterRules(machine, rules),
|
PacketFilter: policy.ReduceFilterRules(machine, rules),
|
||||||
|
|
||||||
UserProfiles: profiles,
|
UserProfiles: profiles,
|
||||||
|
|
||||||
// TODO: Only send if updated
|
|
||||||
SSHPolicy: sshPolicy,
|
SSHPolicy: sshPolicy,
|
||||||
|
|
||||||
ControlTime: &now,
|
ControlTime: &now,
|
||||||
|
KeepAlive: false,
|
||||||
|
OnlineChange: db.OnlineMachineMap(peers),
|
||||||
|
|
||||||
Debug: &tailcfg.Debug{
|
Debug: &tailcfg.Debug{
|
||||||
DisableLogTail: !logtail,
|
DisableLogTail: !logtail,
|
||||||
|
@ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMapResponse returns a MapResponse for the given machine.
|
// FullMapResponse returns a MapResponse for the given machine.
|
||||||
func (m Mapper) CreateMapResponse(
|
func (m Mapper) FullMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
|
@ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.isNoise {
|
if m.isNoise {
|
||||||
return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
|
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress)
|
||||||
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot parse client key")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Mapper) CreateKeepAliveResponse(
|
func (m Mapper) KeepAliveResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
keepAliveResponse := tailcfg.MapResponse{
|
resp := m.baseMapResponse(machine)
|
||||||
KeepAlive: true,
|
resp.KeepAlive = true
|
||||||
|
|
||||||
|
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) DERPMapResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
derpMap tailcfg.DERPMap,
|
||||||
|
) ([]byte, error) {
|
||||||
|
resp := m.baseMapResponse(machine)
|
||||||
|
resp.DERPMap = &derpMap
|
||||||
|
|
||||||
|
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) PeerChangedResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
machineKeys []uint64,
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
) ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
changed := make(types.Machines, len(machineKeys))
|
||||||
|
lastSeen := make(map[tailcfg.NodeID]bool)
|
||||||
|
for idx, machineKey := range machineKeys {
|
||||||
|
peer, err := m.db.GetMachineByID(machineKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.isNoise {
|
changed[idx] = *peer
|
||||||
return m.marshalMapResponse(
|
|
||||||
keepAliveResponse,
|
// We have just seen the node, let the peers update their list.
|
||||||
key.MachinePublic{},
|
lastSeen[tailcfg.NodeID(peer.ID)] = true
|
||||||
mapRequest.Compress,
|
}
|
||||||
|
|
||||||
|
rules, _, err := policy.GenerateFilterAndSSHRules(
|
||||||
|
pol,
|
||||||
|
machine,
|
||||||
|
changed,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Filter out peers that have expired.
|
||||||
|
changed = lo.Filter(changed, func(item types.Machine, index int) bool {
|
||||||
|
return !item.IsExpired()
|
||||||
|
})
|
||||||
|
|
||||||
|
// If there are filter rules present, see if there are any machines that cannot
|
||||||
|
// access eachother at all and remove them from the changed.
|
||||||
|
if len(rules) > 0 {
|
||||||
|
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peers is always returned sorted by Node.ID.
|
||||||
|
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||||
|
return tailPeers[x].ID < tailPeers[y].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := m.baseMapResponse(machine)
|
||||||
|
resp.PeersChanged = tailPeers
|
||||||
|
resp.PeerSeenChange = lastSeen
|
||||||
|
|
||||||
|
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) PeerRemovedResponse(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
removed []tailcfg.NodeID,
|
||||||
|
) ([]byte, error) {
|
||||||
|
resp := m.baseMapResponse(machine)
|
||||||
|
resp.PeersRemoved = removed
|
||||||
|
|
||||||
|
return m.marshalMapResponse(&resp, machine, mapRequest.Compress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mapper) marshalMapResponse(
|
||||||
|
resp *tailcfg.MapResponse,
|
||||||
|
machine *types.Machine,
|
||||||
|
compression string,
|
||||||
|
) ([]byte, error) {
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
|
|
||||||
// If isNoise is set, then the JSON body will be returned
|
|
||||||
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
|
|
||||||
func MarshalResponse(
|
|
||||||
resp interface{},
|
|
||||||
isNoise bool,
|
|
||||||
privateKey2019 *key.MachinePrivate,
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
) ([]byte, error) {
|
|
||||||
jsonBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot marshal response")
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isNoise && privateKey2019 != nil {
|
|
||||||
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return jsonBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Mapper) marshalMapResponse(
|
|
||||||
resp interface{},
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
compression string,
|
|
||||||
) ([]byte, error) {
|
|
||||||
jsonBody, err := json.Marshal(resp)
|
jsonBody, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse(
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalResponse takes an Tailscale Response, marhsal it to JSON.
|
||||||
|
// If isNoise is set, then the JSON body will be returned
|
||||||
|
// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
|
||||||
|
func MarshalResponse(
|
||||||
|
resp interface{},
|
||||||
|
isNoise bool,
|
||||||
|
privateKey2019 *key.MachinePrivate,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) ([]byte, error) {
|
||||||
|
jsonBody, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot marshal response")
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNoise && privateKey2019 != nil {
|
||||||
|
return privateKey2019.SealTo(machineKey, jsonBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
func zstdEncode(in []byte) []byte {
|
func zstdEncode(in []byte) []byte {
|
||||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{
|
||||||
return encoder
|
return encoder
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
resp := tailcfg.MapResponse{
|
||||||
|
KeepAlive: false,
|
||||||
|
ControlTime: &now,
|
||||||
|
}
|
||||||
|
|
||||||
|
online, err := m.db.ListOnlineMachines(machine)
|
||||||
|
if err == nil {
|
||||||
|
resp.OnlineChange = online
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
|
@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
DNSConfig: &tailcfg.DNSConfig{},
|
DNSConfig: &tailcfg.DNSConfig{},
|
||||||
Domain: "",
|
Domain: "",
|
||||||
CollectServices: "false",
|
CollectServices: "false",
|
||||||
|
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
|
||||||
PacketFilter: []tailcfg.FilterRule{},
|
PacketFilter: []tailcfg.FilterRule{},
|
||||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||||
|
@ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
DNSConfig: &tailcfg.DNSConfig{},
|
DNSConfig: &tailcfg.DNSConfig{},
|
||||||
Domain: "",
|
Domain: "",
|
||||||
CollectServices: "false",
|
CollectServices: "false",
|
||||||
|
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
|
||||||
PacketFilter: []tailcfg.FilterRule{
|
PacketFilter: []tailcfg.FilterRule{
|
||||||
{
|
{
|
||||||
SrcIPs: []string{"100.64.0.2/32"},
|
SrcIPs: []string{"100.64.0.2/32"},
|
||||||
|
|
|
@ -3,24 +3,25 @@ package notifier
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
l sync.RWMutex
|
l sync.RWMutex
|
||||||
nodes map[string]chan<- struct{}
|
nodes map[string]chan<- types.StateUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNotifier() *Notifier {
|
func NewNotifier() *Notifier {
|
||||||
return &Notifier{}
|
return &Notifier{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
|
func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) {
|
||||||
n.l.Lock()
|
n.l.Lock()
|
||||||
defer n.l.Unlock()
|
defer n.l.Unlock()
|
||||||
|
|
||||||
if n.nodes == nil {
|
if n.nodes == nil {
|
||||||
n.nodes = make(map[string]chan<- struct{})
|
n.nodes = make(map[string]chan<- types.StateUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
n.nodes[machineKey] = c
|
n.nodes[machineKey] = c
|
||||||
|
@ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) {
|
||||||
delete(n.nodes, machineKey)
|
delete(n.nodes, machineKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) NotifyAll() {
|
func (n *Notifier) NotifyAll(update types.StateUpdate) {
|
||||||
n.NotifyWithIgnore()
|
n.NotifyWithIgnore(update)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) {
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
|
@ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
c <- struct{}{}
|
c <- update
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,7 +116,7 @@ func (h *Headscale) handlePoll(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(err, "Failed to create MapResponse")
|
logErr(err, "Failed to create MapResponse")
|
||||||
http.Error(writer, "", http.StatusInternalServerError)
|
http.Error(writer, "", http.StatusInternalServerError)
|
||||||
|
@ -163,7 +163,12 @@ func (h *Headscale) handlePoll(
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
// Tell all the other nodes about the new endpoint, but dont update ourselves.
|
// Tell all the other nodes about the new endpoint, but dont update ourselves.
|
||||||
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
|
h.nodeNotifier.NotifyWithIgnore(
|
||||||
|
types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
Changed: []uint64{machine.ID},
|
||||||
|
},
|
||||||
|
machine.MachineKey)
|
||||||
|
|
||||||
return
|
return
|
||||||
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
||||||
|
@ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
|
|
||||||
const chanSize = 8
|
const chanSize = 8
|
||||||
updateChan := make(chan struct{}, chanSize)
|
updateChan := make(chan types.StateUpdate, chanSize)
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Add(1)
|
h.pollNetMapStreamWG.Add(1)
|
||||||
defer h.pollNetMapStreamWG.Done()
|
defer h.pollNetMapStreamWG.Done()
|
||||||
|
@ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-keepAliveTicker.C:
|
case <-keepAliveTicker.C:
|
||||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
data, err := mapp.KeepAliveResponse(mapRequest, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(err, "Error generating the keep alive msg")
|
logErr(err, "Error generating the keep alive msg")
|
||||||
|
|
||||||
|
@ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-updateChan:
|
case update := <-updateChan:
|
||||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
var data []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch update.Type {
|
||||||
|
case types.StateFullUpdate:
|
||||||
|
data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
|
case types.StatePeerChanged:
|
||||||
|
data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy)
|
||||||
|
case types.StatePeerRemoved:
|
||||||
|
data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed)
|
||||||
|
case types.StateDERPUpdated:
|
||||||
|
data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(err, "Could not get the map update")
|
logErr(err, "Could not get the create map update")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
|
func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Str("machine", machine).
|
Str("machine", machine).
|
||||||
|
|
|
@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) {
|
||||||
|
|
||||||
return string(bytes), err
|
return string(bytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StateUpdateType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateFullUpdate StateUpdateType = iota
|
||||||
|
StatePeerChanged
|
||||||
|
StatePeerRemoved
|
||||||
|
StateDERPUpdated
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateUpdate is an internal message containing information about
|
||||||
|
// a state change that has happened to the network.
|
||||||
|
type StateUpdate struct {
|
||||||
|
// The type of update
|
||||||
|
Type StateUpdateType
|
||||||
|
|
||||||
|
// Changed must be set when Type is StatePeerChanged and
|
||||||
|
// contain the Machine IDs of machines that has changed.
|
||||||
|
Changed []uint64
|
||||||
|
|
||||||
|
// Removed must be set when Type is StatePeerRemoved and
|
||||||
|
// contain a list of the nodes that has been removed from
|
||||||
|
// the network.
|
||||||
|
Removed []tailcfg.NodeID
|
||||||
|
|
||||||
|
// DERPMap must be set when Type is StateDERPUpdated and
|
||||||
|
// contain the new DERP Map.
|
||||||
|
DERPMap tailcfg.DERPMap
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue