move MapResponse peer logic into function and reuse
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
387aa03adb
commit
432e975a7f
7 changed files with 193 additions and 173 deletions
|
@ -92,6 +92,8 @@ type Headscale struct {
|
||||||
|
|
||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
pollNetMapStreamWG sync.WaitGroup
|
pollNetMapStreamWG sync.WaitGroup
|
||||||
|
|
||||||
|
pollStreamOpenMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
|
|
|
@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
machine := &route.Machine
|
||||||
|
|
||||||
if !route.IsPrimary {
|
if !route.IsPrimary {
|
||||||
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
|
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
|
||||||
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
changedMachines = append(changedMachines, &route.Machine)
|
changedMachines = append(changedMachines, machine)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
changedMachines = append(changedMachines, &route.Machine)
|
changedMachines = append(changedMachines, machine)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,16 @@ const (
|
||||||
|
|
||||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||||
|
|
||||||
|
// TODO: Optimise
|
||||||
|
// As this work continues, the idea is that there will be one Mapper instance
|
||||||
|
// per node, attached to the open stream between the control and client.
|
||||||
|
// This means that this can hold a state per machine and we can use that to
|
||||||
|
// improve the mapresponses sent.
|
||||||
|
// We could:
|
||||||
|
// - Keep information about the previous mapresponse so we can send a diff
|
||||||
|
// - Store hashes
|
||||||
|
// - Create a "minifier" that removes info not needed for the node
|
||||||
|
|
||||||
type Mapper struct {
|
type Mapper struct {
|
||||||
privateKey2019 *key.MachinePrivate
|
privateKey2019 *key.MachinePrivate
|
||||||
isNoise bool
|
isNoise bool
|
||||||
|
@ -102,105 +112,6 @@ func (m *Mapper) String() string {
|
||||||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Optimise
|
|
||||||
// As this work continues, the idea is that there will be one Mapper instance
|
|
||||||
// per node, attached to the open stream between the control and client.
|
|
||||||
// This means that this can hold a state per machine and we can use that to
|
|
||||||
// improve the mapresponses sent.
|
|
||||||
// We could:
|
|
||||||
// - Keep information about the previous mapresponse so we can send a diff
|
|
||||||
// - Store hashes
|
|
||||||
// - Create a "minifier" that removes info not needed for the node
|
|
||||||
|
|
||||||
// fullMapResponse is the internal function for generating a MapResponse
|
|
||||||
// for a machine.
|
|
||||||
func fullMapResponse(
|
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
machine *types.Machine,
|
|
||||||
peers types.Machines,
|
|
||||||
|
|
||||||
baseDomain string,
|
|
||||||
dnsCfg *tailcfg.DNSConfig,
|
|
||||||
derpMap *tailcfg.DERPMap,
|
|
||||||
logtail bool,
|
|
||||||
randomClientPort bool,
|
|
||||||
) (*tailcfg.MapResponse, error) {
|
|
||||||
tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
|
||||||
Node: tailnode,
|
|
||||||
|
|
||||||
DERPMap: derpMap,
|
|
||||||
|
|
||||||
Domain: baseDomain,
|
|
||||||
|
|
||||||
// Do not instruct clients to collect services we do not
|
|
||||||
// support or do anything with them
|
|
||||||
CollectServices: "false",
|
|
||||||
|
|
||||||
ControlTime: &now,
|
|
||||||
KeepAlive: false,
|
|
||||||
OnlineChange: db.OnlineMachineMap(peers),
|
|
||||||
|
|
||||||
Debug: &tailcfg.Debug{
|
|
||||||
DisableLogTail: !logtail,
|
|
||||||
RandomizeClientPort: randomClientPort,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if peers != nil || len(peers) > 0 {
|
|
||||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
|
||||||
pol,
|
|
||||||
machine,
|
|
||||||
peers,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out peers that have expired.
|
|
||||||
peers = filterExpiredAndNotReady(peers)
|
|
||||||
|
|
||||||
// If there are filter rules present, see if there are any machines that cannot
|
|
||||||
// access eachother at all and remove them from the peers.
|
|
||||||
if len(rules) > 0 {
|
|
||||||
peers = policy.FilterMachinesByACL(machine, peers, rules)
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := generateUserProfiles(machine, peers, baseDomain)
|
|
||||||
|
|
||||||
dnsConfig := generateDNSConfig(
|
|
||||||
dnsCfg,
|
|
||||||
baseDomain,
|
|
||||||
machine,
|
|
||||||
peers,
|
|
||||||
)
|
|
||||||
|
|
||||||
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peers is always returned sorted by Node.ID.
|
|
||||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
|
||||||
return tailPeers[x].ID < tailPeers[y].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
resp.Peers = tailPeers
|
|
||||||
resp.DNSConfig = dnsConfig
|
|
||||||
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
|
||||||
resp.UserProfiles = profiles
|
|
||||||
resp.SSHPolicy = sshPolicy
|
|
||||||
}
|
|
||||||
|
|
||||||
return &resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateUserProfiles(
|
func generateUserProfiles(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
peers types.Machines,
|
peers types.Machines,
|
||||||
|
@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fullMapResponse creates a complete MapResponse for a node.
|
||||||
|
// It is a separate function to make testing easier.
|
||||||
|
func (m *Mapper) fullMapResponse(
|
||||||
|
machine *types.Machine,
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
) (*tailcfg.MapResponse, error) {
|
||||||
|
peers := machineMapToList(m.peers)
|
||||||
|
|
||||||
|
resp, err := m.baseWithConfigMapResponse(machine, pol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Move this into appendPeerChanges?
|
||||||
|
resp.OnlineChange = db.OnlineMachineMap(peers)
|
||||||
|
|
||||||
|
err = appendPeerChanges(
|
||||||
|
resp,
|
||||||
|
pol,
|
||||||
|
machine,
|
||||||
|
peers,
|
||||||
|
peers,
|
||||||
|
m.baseDomain,
|
||||||
|
m.dnsCfg,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// FullMapResponse returns a MapResponse for the given machine.
|
// FullMapResponse returns a MapResponse for the given machine.
|
||||||
func (m *Mapper) FullMapResponse(
|
func (m *Mapper) FullMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
|
@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse(
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
mapResponse, err := fullMapResponse(
|
resp, err := m.fullMapResponse(machine, pol)
|
||||||
pol,
|
|
||||||
machine,
|
|
||||||
machineMapToList(m.peers),
|
|
||||||
m.baseDomain,
|
|
||||||
m.dnsCfg,
|
|
||||||
m.derpMap,
|
|
||||||
m.logtail,
|
|
||||||
m.randomClientPort,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.isNoise {
|
if m.isNoise {
|
||||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LiteMapResponse returns a MapResponse for the given machine.
|
// LiteMapResponse returns a MapResponse for the given machine.
|
||||||
|
@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
mapResponse, err := fullMapResponse(
|
resp, err := m.baseWithConfigMapResponse(machine, pol)
|
||||||
pol,
|
|
||||||
machine,
|
|
||||||
nil,
|
|
||||||
m.baseDomain,
|
|
||||||
m.dnsCfg,
|
|
||||||
m.derpMap,
|
|
||||||
m.logtail,
|
|
||||||
m.randomClientPort,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.isNoise {
|
if m.isNoise {
|
||||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mapper) KeepAliveResponse(
|
func (m *Mapper) KeepAliveResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp := m.baseMapResponse(machine)
|
resp := m.baseMapResponse()
|
||||||
resp.KeepAlive = true
|
resp.KeepAlive = true
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||||
|
@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
derpMap tailcfg.DERPMap,
|
derpMap tailcfg.DERPMap,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp := m.baseMapResponse(machine)
|
resp := m.baseMapResponse()
|
||||||
resp.DERPMap = &derpMap
|
resp.DERPMap = &derpMap
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||||
|
@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse(
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
|
||||||
lastSeen := make(map[tailcfg.NodeID]bool)
|
lastSeen := make(map[tailcfg.NodeID]bool)
|
||||||
|
|
||||||
// Update our internal map.
|
// Update our internal map.
|
||||||
|
@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse(
|
||||||
lastSeen[tailcfg.NodeID(machine.ID)] = true
|
lastSeen[tailcfg.NodeID(machine.ID)] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
|
err := appendPeerChanges(
|
||||||
|
&resp,
|
||||||
pol,
|
pol,
|
||||||
machine,
|
machine,
|
||||||
machineMapToList(m.peers),
|
machineMapToList(m.peers),
|
||||||
|
changed,
|
||||||
|
m.baseDomain,
|
||||||
|
m.dnsCfg,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
changed = filterExpiredAndNotReady(changed)
|
|
||||||
|
|
||||||
// If there are filter rules present, see if there are any machines that cannot
|
|
||||||
// access eachother at all and remove them from the changed.
|
|
||||||
if len(rules) > 0 {
|
|
||||||
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
|
||||||
}
|
|
||||||
|
|
||||||
tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peers is always returned sorted by Node.ID.
|
|
||||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
|
||||||
return tailPeers[x].ID < tailPeers[y].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
resp := m.baseMapResponse(machine)
|
|
||||||
resp.PeersChanged = tailPeers
|
|
||||||
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
|
||||||
resp.SSHPolicy = sshPolicy
|
|
||||||
// resp.PeerSeenChange = lastSeen
|
// resp.PeerSeenChange = lastSeen
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||||
|
@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse(
|
||||||
delete(m.peers, uint64(id))
|
delete(m.peers, uint64(id))
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := m.baseMapResponse(machine)
|
resp := m.baseMapResponse()
|
||||||
resp.PeersRemoved = removed
|
resp.PeersRemoved = removed
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
|
||||||
|
@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse(
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
mapResponsePath := path.Join(
|
mapResponsePath := path.Join(
|
||||||
mPath,
|
mPath,
|
||||||
|
@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
|
// baseMapResponse returns a tailcfg.MapResponse with
|
||||||
|
// KeepAlive false and ControlTime set to now.
|
||||||
|
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
resp := tailcfg.MapResponse{
|
||||||
|
@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
|
||||||
ControlTime: &now,
|
ControlTime: &now,
|
||||||
}
|
}
|
||||||
|
|
||||||
// online, err := m.db.ListOnlineMachines(machine)
|
|
||||||
// if err == nil {
|
|
||||||
// resp.OnlineChange = online
|
|
||||||
// }
|
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
||||||
|
// with the basic configuration from headscale set.
|
||||||
|
// It is used in for bigger updates, such as full and lite, not
|
||||||
|
// incremental.
|
||||||
|
func (m *Mapper) baseWithConfigMapResponse(
|
||||||
|
machine *types.Machine,
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
) (*tailcfg.MapResponse, error) {
|
||||||
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
|
tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp.Node = tailnode
|
||||||
|
|
||||||
|
resp.DERPMap = m.derpMap
|
||||||
|
|
||||||
|
resp.Domain = m.baseDomain
|
||||||
|
|
||||||
|
// Do not instruct clients to collect services we do not
|
||||||
|
// support or do anything with them
|
||||||
|
resp.CollectServices = "false"
|
||||||
|
|
||||||
|
resp.KeepAlive = false
|
||||||
|
|
||||||
|
resp.Debug = &tailcfg.Debug{
|
||||||
|
DisableLogTail: !m.logtail,
|
||||||
|
RandomizeClientPort: m.randomClientPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
|
func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
|
||||||
ret := make(types.Machines, 0)
|
ret := make(types.Machines, 0)
|
||||||
|
|
||||||
|
@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines {
|
||||||
return !item.IsExpired() || len(item.Endpoints) > 0
|
return !item.IsExpired() || len(item.Endpoints) > 0
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
||||||
|
// necessary changes when peers have changed.
|
||||||
|
func appendPeerChanges(
|
||||||
|
resp *tailcfg.MapResponse,
|
||||||
|
|
||||||
|
pol *policy.ACLPolicy,
|
||||||
|
machine *types.Machine,
|
||||||
|
peers types.Machines,
|
||||||
|
changed types.Machines,
|
||||||
|
baseDomain string,
|
||||||
|
dnsCfg *tailcfg.DNSConfig,
|
||||||
|
) error {
|
||||||
|
fullChange := len(peers) == len(changed)
|
||||||
|
|
||||||
|
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
||||||
|
pol,
|
||||||
|
machine,
|
||||||
|
peers,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out peers that have expired.
|
||||||
|
changed = filterExpiredAndNotReady(changed)
|
||||||
|
|
||||||
|
// If there are filter rules present, see if there are any machines that cannot
|
||||||
|
// access eachother at all and remove them from the peers.
|
||||||
|
if len(rules) > 0 {
|
||||||
|
changed = policy.FilterMachinesByACL(machine, changed, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := generateUserProfiles(machine, changed, baseDomain)
|
||||||
|
|
||||||
|
dnsConfig := generateDNSConfig(
|
||||||
|
dnsCfg,
|
||||||
|
baseDomain,
|
||||||
|
machine,
|
||||||
|
peers,
|
||||||
|
)
|
||||||
|
|
||||||
|
tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peers is always returned sorted by Node.ID.
|
||||||
|
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||||
|
return tailPeers[x].ID < tailPeers[y].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
if fullChange {
|
||||||
|
resp.Peers = tailPeers
|
||||||
|
} else {
|
||||||
|
resp.PeersChanged = tailPeers
|
||||||
|
}
|
||||||
|
resp.DNSConfig = dnsConfig
|
||||||
|
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
|
||||||
|
resp.UserProfiles = profiles
|
||||||
|
resp.SSHPolicy = sshPolicy
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -441,7 +441,9 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
UserProfiles: []tailcfg.UserProfile{
|
||||||
|
{LoginName: "mini", DisplayName: "mini"},
|
||||||
|
},
|
||||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||||
ControlTime: &time.Time{},
|
ControlTime: &time.Time{},
|
||||||
Debug: &tailcfg.Debug{
|
Debug: &tailcfg.Debug{
|
||||||
|
@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := fullMapResponse(
|
mappy := NewMapper(
|
||||||
tt.pol,
|
|
||||||
tt.machine,
|
tt.machine,
|
||||||
tt.peers,
|
tt.peers,
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
|
tt.derpMap,
|
||||||
tt.baseDomain,
|
tt.baseDomain,
|
||||||
tt.dnsConfig,
|
tt.dnsConfig,
|
||||||
tt.derpMap,
|
|
||||||
tt.logtail,
|
tt.logtail,
|
||||||
tt.randomClientPort,
|
tt.randomClientPort,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
got, err := mappy.fullMapResponse(
|
||||||
|
tt.machine,
|
||||||
|
tt.pol,
|
||||||
|
)
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,8 @@ func logPollFunc(
|
||||||
|
|
||||||
// handlePoll is the common code for the legacy and Noise protocols to
|
// handlePoll is the common code for the legacy and Noise protocols to
|
||||||
// managed the poll loop.
|
// managed the poll loop.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo
|
||||||
func (h *Headscale) handlePoll(
|
func (h *Headscale) handlePoll(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -67,6 +69,7 @@ func (h *Headscale) handlePoll(
|
||||||
// following updates missing
|
// following updates missing
|
||||||
var updateChan chan types.StateUpdate
|
var updateChan chan types.StateUpdate
|
||||||
if mapRequest.Stream {
|
if mapRequest.Stream {
|
||||||
|
h.pollStreamOpenMu.Lock()
|
||||||
h.pollNetMapStreamWG.Add(1)
|
h.pollNetMapStreamWG.Add(1)
|
||||||
defer h.pollNetMapStreamWG.Done()
|
defer h.pollNetMapStreamWG.Done()
|
||||||
|
|
||||||
|
@ -251,6 +254,8 @@ func (h *Headscale) handlePoll(
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
h.pollStreamOpenMu.Unlock()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
logInfo("Waiting for update on stream channel")
|
logInfo("Waiting for update on stream channel")
|
||||||
select {
|
select {
|
||||||
|
|
|
@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
// Omit 1.16.2 (-1) because it does not have the FQDN field
|
"magicdns1": len(MustTestVersions),
|
||||||
"magicdns1": len(MustTestVersions) - 1,
|
"magicdns2": len(MustTestVersions),
|
||||||
"magicdns2": len(MustTestVersions) - 1,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tsicHashLength = 6
|
tsicHashLength = 6
|
||||||
|
defaultPingTimeout = 300 * time.Millisecond
|
||||||
defaultPingCount = 10
|
defaultPingCount = 10
|
||||||
dockerContextPath = "../."
|
dockerContextPath = "../."
|
||||||
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
||||||
|
@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption {
|
||||||
// TODO(kradalby): Make multiping, go routine magic.
|
// TODO(kradalby): Make multiping, go routine magic.
|
||||||
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
|
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
|
||||||
args := pingArgs{
|
args := pingArgs{
|
||||||
timeout: 300 * time.Millisecond,
|
timeout: defaultPingTimeout,
|
||||||
count: defaultPingCount,
|
count: defaultPingCount,
|
||||||
direct: true,
|
direct: true,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue