83769ba715
This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
681 lines
16 KiB
Go
681 lines
16 KiB
Go
package mapper
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/fs"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
mapset "github.com/deckarep/golang-set/v2"
|
|
"github.com/juanfont/headscale/hscontrol/policy"
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/exp/maps"
|
|
"tailscale.com/envknob"
|
|
"tailscale.com/smallzstd"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/dnstype"
|
|
)
|
|
|
|
const (
|
|
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
|
reservedResponseHeaderSize = 4
|
|
mapperIDLength = 8
|
|
debugMapResponsePerm = 0o755
|
|
)
|
|
|
|
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
|
|
|
// TODO: Optimise
|
|
// As this work continues, the idea is that there will be one Mapper instance
|
|
// per node, attached to the open stream between the control and client.
|
|
// This means that this can hold a state per node and we can use that to
|
|
// improve the mapresponses sent.
|
|
// We could:
|
|
// - Keep information about the previous mapresponse so we can send a diff
|
|
// - Store hashes
|
|
// - Create a "minifier" that removes info not needed for the node
|
|
// - some sort of batching, wait for 5 or 60 seconds before sending
|
|
|
|
type Mapper struct {
|
|
// Configuration
|
|
// TODO(kradalby): figure out if this is the format we want this in
|
|
derpMap *tailcfg.DERPMap
|
|
baseDomain string
|
|
dnsCfg *tailcfg.DNSConfig
|
|
logtail bool
|
|
randomClientPort bool
|
|
|
|
uid string
|
|
created time.Time
|
|
seq uint64
|
|
|
|
// Map isnt concurrency safe, so we need to ensure
|
|
// only one func is accessing it over time.
|
|
mu sync.Mutex
|
|
peers map[uint64]*types.Node
|
|
patches map[uint64][]patch
|
|
}
|
|
|
|
type patch struct {
|
|
timestamp time.Time
|
|
change *tailcfg.PeerChange
|
|
}
|
|
|
|
func NewMapper(
|
|
node *types.Node,
|
|
peers types.Nodes,
|
|
derpMap *tailcfg.DERPMap,
|
|
baseDomain string,
|
|
dnsCfg *tailcfg.DNSConfig,
|
|
logtail bool,
|
|
randomClientPort bool,
|
|
) *Mapper {
|
|
log.Debug().
|
|
Caller().
|
|
Str("node", node.Hostname).
|
|
Msg("creating new mapper")
|
|
|
|
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
|
|
|
return &Mapper{
|
|
derpMap: derpMap,
|
|
baseDomain: baseDomain,
|
|
dnsCfg: dnsCfg,
|
|
logtail: logtail,
|
|
randomClientPort: randomClientPort,
|
|
|
|
uid: uid,
|
|
created: time.Now(),
|
|
seq: 0,
|
|
|
|
// TODO: populate
|
|
peers: peers.IDMap(),
|
|
patches: make(map[uint64][]patch),
|
|
}
|
|
}
|
|
|
|
func (m *Mapper) String() string {
|
|
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
|
}
|
|
|
|
func generateUserProfiles(
|
|
node *types.Node,
|
|
peers types.Nodes,
|
|
baseDomain string,
|
|
) []tailcfg.UserProfile {
|
|
userMap := make(map[string]types.User)
|
|
userMap[node.User.Name] = node.User
|
|
for _, peer := range peers {
|
|
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
|
}
|
|
|
|
profiles := []tailcfg.UserProfile{}
|
|
for _, user := range userMap {
|
|
displayName := user.Name
|
|
|
|
if baseDomain != "" {
|
|
displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain)
|
|
}
|
|
|
|
profiles = append(profiles,
|
|
tailcfg.UserProfile{
|
|
ID: tailcfg.UserID(user.ID),
|
|
LoginName: user.Name,
|
|
DisplayName: displayName,
|
|
})
|
|
}
|
|
|
|
return profiles
|
|
}
|
|
|
|
func generateDNSConfig(
|
|
base *tailcfg.DNSConfig,
|
|
baseDomain string,
|
|
node *types.Node,
|
|
peers types.Nodes,
|
|
) *tailcfg.DNSConfig {
|
|
dnsConfig := base.Clone()
|
|
|
|
// if MagicDNS is enabled
|
|
if base != nil && base.Proxied {
|
|
// Only inject the Search Domain of the current user
|
|
// shared nodes should use their full FQDN
|
|
dnsConfig.Domains = append(
|
|
dnsConfig.Domains,
|
|
fmt.Sprintf(
|
|
"%s.%s",
|
|
node.User.Name,
|
|
baseDomain,
|
|
),
|
|
)
|
|
|
|
userSet := mapset.NewSet[types.User]()
|
|
userSet.Add(node.User)
|
|
for _, p := range peers {
|
|
userSet.Add(p.User)
|
|
}
|
|
for _, user := range userSet.ToSlice() {
|
|
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
|
dnsConfig.Routes[dnsRoute] = nil
|
|
}
|
|
} else {
|
|
dnsConfig = base
|
|
}
|
|
|
|
addNextDNSMetadata(dnsConfig.Resolvers, node)
|
|
|
|
return dnsConfig
|
|
}
|
|
|
|
// If any nextdns DoH resolvers are present in the list of resolvers it will
|
|
// take metadata from the node metadata and instruct tailscale to add it
|
|
// to the requests. This makes it possible to identify from which device the
|
|
// requests come in the NextDNS dashboard.
|
|
//
|
|
// This will produce a resolver like:
|
|
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
|
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|
for _, resolver := range resolvers {
|
|
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
|
attrs := url.Values{
|
|
"device_name": []string{node.Hostname},
|
|
"device_model": []string{node.Hostinfo.OS},
|
|
}
|
|
|
|
if len(node.IPAddresses) > 0 {
|
|
attrs.Add("device_ip", node.IPAddresses[0].String())
|
|
}
|
|
|
|
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
|
}
|
|
}
|
|
}
|
|
|
|
// fullMapResponse creates a complete MapResponse for a node.
|
|
// It is a separate function to make testing easier.
|
|
func (m *Mapper) fullMapResponse(
|
|
node *types.Node,
|
|
pol *policy.ACLPolicy,
|
|
capVer tailcfg.CapabilityVersion,
|
|
) (*tailcfg.MapResponse, error) {
|
|
peers := nodeMapToList(m.peers)
|
|
|
|
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = appendPeerChanges(
|
|
resp,
|
|
pol,
|
|
node,
|
|
capVer,
|
|
peers,
|
|
peers,
|
|
m.baseDomain,
|
|
m.dnsCfg,
|
|
m.randomClientPort,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// FullMapResponse returns a MapResponse for the given node.
|
|
func (m *Mapper) FullMapResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
pol *policy.ACLPolicy,
|
|
) ([]byte, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
peers := maps.Keys(m.peers)
|
|
peersWithPatches := maps.Keys(m.patches)
|
|
slices.Sort(peers)
|
|
slices.Sort(peersWithPatches)
|
|
|
|
if len(peersWithPatches) > 0 {
|
|
log.Debug().
|
|
Str("node", node.Hostname).
|
|
Uints64("peers", peers).
|
|
Uints64("pending_patches", peersWithPatches).
|
|
Msgf("node requested full map response, but has pending patches")
|
|
}
|
|
|
|
resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress)
|
|
}
|
|
|
|
// LiteMapResponse returns a MapResponse for the given node.
|
|
// Lite means that the peers has been omitted, this is intended
|
|
// to be used to answer MapRequests with OmitPeers set to true.
|
|
func (m *Mapper) LiteMapResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
pol *policy.ACLPolicy,
|
|
messages ...string,
|
|
) ([]byte, error) {
|
|
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
|
pol,
|
|
node,
|
|
nodeMapToList(m.peers),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
|
|
resp.SSHPolicy = sshPolicy
|
|
|
|
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
|
}
|
|
|
|
func (m *Mapper) KeepAliveResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
) ([]byte, error) {
|
|
resp := m.baseMapResponse()
|
|
resp.KeepAlive = true
|
|
|
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
}
|
|
|
|
func (m *Mapper) DERPMapResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
derpMap *tailcfg.DERPMap,
|
|
) ([]byte, error) {
|
|
m.derpMap = derpMap
|
|
|
|
resp := m.baseMapResponse()
|
|
resp.DERPMap = derpMap
|
|
|
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
}
|
|
|
|
func (m *Mapper) PeerChangedResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
changed types.Nodes,
|
|
pol *policy.ACLPolicy,
|
|
messages ...string,
|
|
) ([]byte, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Update our internal map.
|
|
for _, node := range changed {
|
|
if patches, ok := m.patches[node.ID]; ok {
|
|
// preserve online status in case the patch has an outdated one
|
|
online := node.IsOnline
|
|
|
|
for _, p := range patches {
|
|
// TODO(kradalby): Figure if this needs to be sorted by timestamp
|
|
node.ApplyPeerChange(p.change)
|
|
}
|
|
|
|
// Ensure the patches are not applied again later
|
|
delete(m.patches, node.ID)
|
|
|
|
node.IsOnline = online
|
|
}
|
|
|
|
m.peers[node.ID] = node
|
|
}
|
|
|
|
resp := m.baseMapResponse()
|
|
|
|
err := appendPeerChanges(
|
|
&resp,
|
|
pol,
|
|
node,
|
|
mapRequest.Version,
|
|
nodeMapToList(m.peers),
|
|
changed,
|
|
m.baseDomain,
|
|
m.dnsCfg,
|
|
m.randomClientPort,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
|
}
|
|
|
|
// PeerChangedPatchResponse creates a patch MapResponse with
|
|
// incoming update from a state change.
|
|
func (m *Mapper) PeerChangedPatchResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
changed []*tailcfg.PeerChange,
|
|
pol *policy.ACLPolicy,
|
|
) ([]byte, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
sendUpdate := false
|
|
// patch the internal map
|
|
for _, change := range changed {
|
|
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
|
|
peer.ApplyPeerChange(change)
|
|
sendUpdate = true
|
|
} else {
|
|
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
|
|
|
|
p := patch{
|
|
timestamp: time.Now(),
|
|
change: change,
|
|
}
|
|
|
|
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
|
|
m.patches[uint64(change.NodeID)] = append(patches, p)
|
|
} else {
|
|
m.patches[uint64(change.NodeID)] = []patch{p}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !sendUpdate {
|
|
return nil, nil
|
|
}
|
|
|
|
resp := m.baseMapResponse()
|
|
resp.PeersChangedPatch = changed
|
|
|
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
}
|
|
|
|
// TODO(kradalby): We need some integration tests for this.
|
|
func (m *Mapper) PeerRemovedResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
node *types.Node,
|
|
removed []tailcfg.NodeID,
|
|
) ([]byte, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Some nodes might have been removed already
|
|
// so we dont want to ask downstream to remove
|
|
// twice, than can cause a panic in tailscaled.
|
|
notYetRemoved := []tailcfg.NodeID{}
|
|
|
|
// remove from our internal map
|
|
for _, id := range removed {
|
|
if _, ok := m.peers[uint64(id)]; ok {
|
|
notYetRemoved = append(notYetRemoved, id)
|
|
}
|
|
|
|
delete(m.peers, uint64(id))
|
|
delete(m.patches, uint64(id))
|
|
}
|
|
|
|
resp := m.baseMapResponse()
|
|
resp.PeersRemoved = notYetRemoved
|
|
|
|
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
}
|
|
|
|
func (m *Mapper) marshalMapResponse(
|
|
mapRequest tailcfg.MapRequest,
|
|
resp *tailcfg.MapResponse,
|
|
node *types.Node,
|
|
compression string,
|
|
messages ...string,
|
|
) ([]byte, error) {
|
|
atomic.AddUint64(&m.seq, 1)
|
|
|
|
jsonBody, err := json.Marshal(resp)
|
|
if err != nil {
|
|
log.Error().
|
|
Caller().
|
|
Err(err).
|
|
Msg("Cannot marshal map response")
|
|
}
|
|
|
|
if debugDumpMapResponsePath != "" {
|
|
data := map[string]interface{}{
|
|
"Messages": messages,
|
|
"MapRequest": mapRequest,
|
|
"MapResponse": resp,
|
|
}
|
|
|
|
responseType := "keepalive"
|
|
|
|
switch {
|
|
case resp.Peers != nil && len(resp.Peers) > 0:
|
|
responseType = "full"
|
|
case isSelfUpdate(messages...):
|
|
responseType = "self"
|
|
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
|
|
responseType = "lite"
|
|
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
|
responseType = "changed"
|
|
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
|
responseType = "patch"
|
|
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
|
responseType = "removed"
|
|
}
|
|
|
|
body, err := json.MarshalIndent(data, "", " ")
|
|
if err != nil {
|
|
log.Error().
|
|
Caller().
|
|
Err(err).
|
|
Msg("Cannot marshal map response")
|
|
}
|
|
|
|
perms := fs.FileMode(debugMapResponsePerm)
|
|
mPath := path.Join(debugDumpMapResponsePath, node.Hostname)
|
|
err = os.MkdirAll(mPath, perms)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
now := time.Now().UnixNano()
|
|
|
|
mapResponsePath := path.Join(
|
|
mPath,
|
|
fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
|
)
|
|
|
|
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
|
err = os.WriteFile(mapResponsePath, body, perms)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
var respBody []byte
|
|
if compression == util.ZstdCompression {
|
|
respBody = zstdEncode(jsonBody)
|
|
} else {
|
|
respBody = jsonBody
|
|
}
|
|
|
|
data := make([]byte, reservedResponseHeaderSize)
|
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
|
data = append(data, respBody...)
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func zstdEncode(in []byte) []byte {
|
|
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
|
if !ok {
|
|
panic("invalid type in sync pool")
|
|
}
|
|
out := encoder.EncodeAll(in, nil)
|
|
_ = encoder.Close()
|
|
zstdEncoderPool.Put(encoder)
|
|
|
|
return out
|
|
}
|
|
|
|
var zstdEncoderPool = &sync.Pool{
|
|
New: func() any {
|
|
encoder, err := smallzstd.NewEncoder(
|
|
nil,
|
|
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return encoder
|
|
},
|
|
}
|
|
|
|
// baseMapResponse returns a tailcfg.MapResponse with
|
|
// KeepAlive false and ControlTime set to now.
|
|
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
|
now := time.Now()
|
|
|
|
resp := tailcfg.MapResponse{
|
|
KeepAlive: false,
|
|
ControlTime: &now,
|
|
// TODO(kradalby): Implement PingRequest?
|
|
}
|
|
|
|
return resp
|
|
}
|
|
|
|
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
|
// with the basic configuration from headscale set.
|
|
// It is used in for bigger updates, such as full and lite, not
|
|
// incremental.
|
|
func (m *Mapper) baseWithConfigMapResponse(
|
|
node *types.Node,
|
|
pol *policy.ACLPolicy,
|
|
capVer tailcfg.CapabilityVersion,
|
|
) (*tailcfg.MapResponse, error) {
|
|
resp := m.baseMapResponse()
|
|
|
|
tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp.Node = tailnode
|
|
|
|
resp.DERPMap = m.derpMap
|
|
|
|
resp.Domain = m.baseDomain
|
|
|
|
// Do not instruct clients to collect services we do not
|
|
// support or do anything with them
|
|
resp.CollectServices = "false"
|
|
|
|
resp.KeepAlive = false
|
|
|
|
resp.Debug = &tailcfg.Debug{
|
|
DisableLogTail: !m.logtail,
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
|
|
ret := make(types.Nodes, 0)
|
|
|
|
for _, node := range nodes {
|
|
ret = append(ret, node)
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
|
// necessary changes when peers have changed.
|
|
func appendPeerChanges(
|
|
resp *tailcfg.MapResponse,
|
|
|
|
pol *policy.ACLPolicy,
|
|
node *types.Node,
|
|
capVer tailcfg.CapabilityVersion,
|
|
peers types.Nodes,
|
|
changed types.Nodes,
|
|
baseDomain string,
|
|
dnsCfg *tailcfg.DNSConfig,
|
|
randomClientPort bool,
|
|
) error {
|
|
fullChange := len(peers) == len(changed)
|
|
|
|
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
|
|
pol,
|
|
node,
|
|
peers,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If there are filter rules present, see if there are any nodes that cannot
|
|
// access eachother at all and remove them from the peers.
|
|
if len(rules) > 0 {
|
|
changed = policy.FilterNodesByACL(node, changed, rules)
|
|
}
|
|
|
|
profiles := generateUserProfiles(node, changed, baseDomain)
|
|
|
|
dnsConfig := generateDNSConfig(
|
|
dnsCfg,
|
|
baseDomain,
|
|
node,
|
|
peers,
|
|
)
|
|
|
|
tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Peers is always returned sorted by Node.ID.
|
|
sort.SliceStable(tailPeers, func(x, y int) bool {
|
|
return tailPeers[x].ID < tailPeers[y].ID
|
|
})
|
|
|
|
if fullChange {
|
|
resp.Peers = tailPeers
|
|
} else {
|
|
resp.PeersChanged = tailPeers
|
|
}
|
|
resp.DNSConfig = dnsConfig
|
|
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
|
|
resp.UserProfiles = profiles
|
|
resp.SSHPolicy = sshPolicy
|
|
|
|
return nil
|
|
}
|
|
|
|
func isSelfUpdate(messages ...string) bool {
|
|
for _, message := range messages {
|
|
if strings.Contains(message, types.SelfUpdateIdentifier) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|