Implement node sharing functionality

This commit is contained in:
Juan Font 2021-09-02 16:59:03 +02:00
parent 1ecd0d7ca4
commit 48b73fa12f
3 changed files with 54 additions and 26 deletions

5
api.go
View file

@ -33,8 +33,6 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
return return
} }
// spew.Dump(c.Params)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
@ -220,7 +218,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := m.toNode() node, err := m.toNode(true)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -280,7 +278,6 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
return nil, err return nil, err
} }
} }
// spew.Dump(resp)
// declare the incoming size on the first 4 bytes // declare the incoming size on the first 4 bytes
data := make([]byte, 4) data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))

View file

@ -50,7 +50,7 @@ func (m Machine) isAlreadyRegistered() bool {
return m.Registered return m.Registered
} }
func (m Machine) toNode() (*tailcfg.Node, error) { func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey) nKey, err := wgkey.ParseHex(m.NodeKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -85,24 +85,26 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
allowedIPs := []netaddr.IPPrefix{} allowedIPs := []netaddr.IPPrefix{}
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
routesStr := []string{} if includeRoutes {
if len(m.EnabledRoutes) != 0 { routesStr := []string{}
allwIps, err := m.EnabledRoutes.MarshalJSON() if len(m.EnabledRoutes) != 0 {
if err != nil { allwIps, err := m.EnabledRoutes.MarshalJSON()
return nil, err if err != nil {
return nil, err
}
err = json.Unmarshal(allwIps, &routesStr)
if err != nil {
return nil, err
}
} }
err = json.Unmarshal(allwIps, &routesStr)
if err != nil {
return nil, err
}
}
for _, aip := range routesStr { for _, aip := range routesStr {
ip, err := netaddr.ParseIPPrefix(aip) ip, err := netaddr.ParseIPPrefix(aip)
if err != nil { if err != nil {
return nil, err return nil, err
}
allowedIPs = append(allowedIPs, ip)
} }
allowedIPs = append(allowedIPs, ip)
} }
endpoints := []string{} endpoints := []string{}
@ -136,13 +138,20 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
derp = "127.3.3.40:0" // Zero means disconnected or unknown. derp = "127.3.3.40:0" // Zero means disconnected or unknown.
} }
var keyExpiry time.Time
if m.Expiry != nil {
keyExpiry = *m.Expiry
} else {
keyExpiry = time.Time{}
}
n := tailcfg.Node{ n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID ID: tailcfg.NodeID(m.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostinfo.Hostname, Name: hostinfo.Hostname,
User: tailcfg.UserID(m.NamespaceID), User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey), Key: tailcfg.NodeKey(nKey),
KeyExpiry: *m.Expiry, KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey), Machine: tailcfg.MachineKey(mKey),
DiscoKey: discoKey, DiscoKey: discoKey,
Addresses: addrs, Addresses: addrs,
@ -165,6 +174,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
Str("func", "getPeers"). Str("func", "getPeers").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Finding peers") Msg("Finding peers")
machines := []Machine{} machines := []Machine{}
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
@ -172,9 +182,23 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
return nil, err return nil, err
} }
// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
sharedNodes := []SharedNode{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedNodes).Error; err != nil {
return nil, err
}
peers := []*tailcfg.Node{} peers := []*tailcfg.Node{}
for _, mn := range machines { for _, mn := range machines {
peer, err := mn.toNode() peer, err := mn.toNode(true)
if err != nil {
return nil, err
}
peers = append(peers, peer)
}
for _, sn := range sharedNodes {
peer, err := sn.Machine.toNode(false) // shared nodes do not expose their routes
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -201,7 +225,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil return &m, nil
} }
} }
return nil, fmt.Errorf("not found") return nil, fmt.Errorf("machine not found")
} }
// GetMachineByID finds a Machine by ID and returns the Machine struct // GetMachineByID finds a Machine by ID and returns the Machine struct
@ -260,7 +284,14 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
} }
func (h *Headscale) notifyChangesToPeers(m *Machine) { func (h *Headscale) notifyChangesToPeers(m *Machine) {
peers, _ := h.getPeers(*m) peers, err := h.getPeers(*m)
if err != nil {
log.Error().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Msgf("Error getting peers: %s", err)
return
}
for _, p := range *peers { for _, p := range *peers {
log.Info(). log.Info().
Str("func", "notifyChangesToPeers"). Str("func", "notifyChangesToPeers").

View file

@ -440,7 +440,7 @@ func (h *Headscale) scheduledPollWorker(
case <-updateCheckerTicker.C: case <-updateCheckerTicker.C:
// Send an update request regardless of outdated or not, if data is sent // Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block // to the node is determined in the updateChan consumer block
n, _ := m.toNode() n, _ := m.toNode(true)
err := h.sendRequestOnUpdateChannel(n) err := h.sendRequestOnUpdateChannel(n)
if err != nil { if err != nil {
log.Error(). log.Error().