Get peers from namespaces where shared nodes are shared to

This is rather shameful. Shared nodes should have never worked without this.
This commit is contained in:
Juan Font Alonso 2021-10-12 17:20:14 +02:00
parent dd1e425d02
commit fa8cd96108

View file

@ -78,13 +78,13 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
return machines, nil return machines, nil
} }
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
func (h *Headscale) getShared(m *Machine) (Machines, error) { func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Str("func", "getShared"). Str("func", "getShared").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Finding shared peers") Msg("Finding shared peers")
// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil { m.NamespaceID).Find(&sharedMachines).Error; err != nil {
@ -105,6 +105,37 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
return peers, nil return peers, nil
} }
// getSharedTo fetches the machines of the namespaces this machine is shared in
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Str("func", "getSharedTo").
Str("machine", m.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines {
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
if err != nil {
return Machines{}, err
}
peers = append(peers, *namespaceMachines...)
}
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Str("func", "getSharedTo").
Str("machine", m.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
}
func (h *Headscale) getPeers(m *Machine) (Machines, error) { func (h *Headscale) getPeers(m *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m) direct, err := h.getDirectPeers(m)
if err != nil { if err != nil {
@ -118,13 +149,24 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
shared, err := h.getShared(m) shared, err := h.getShared(m)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getDirectPeers"). Str("func", "getShared").
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
}
sharedTo, err := h.getSharedTo(m)
if err != nil {
log.Error().
Str("func", "sharedTo").
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
peers := append(direct, shared...) peers := append(direct, shared...)
peers = append(peers, sharedTo...)
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace(). log.Trace().