diff --git a/acls.go b/acls.go index db2fc58..ce14a89 100644 --- a/acls.go +++ b/acls.go @@ -90,7 +90,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { return nil, errEmptyPolicy } - machines, err := h.ListAllMachines() + machines, err := h.ListMachines() if err != nil { return nil, err } diff --git a/machine.go b/machine.go index 7656304..603441d 100644 --- a/machine.go +++ b/machine.go @@ -118,19 +118,6 @@ func (machine Machine) isExpired() bool { return time.Now().UTC().After(*machine.Expiry) } -func (h *Headscale) ListAllMachines() ([]Machine, error) { - machines := []Machine{} - if err := h.db.Preload("AuthKey"). - Preload("AuthKey.Namespace"). - Preload("Namespace"). - Where("registered"). - Find(&machines).Error; err != nil { - return nil, err - } - - return machines, nil -} - func containsAddresses(inputs []string, addrs []string) bool { for _, addr := range addrs { if containsString(inputs, addr) { @@ -215,15 +202,15 @@ func getFilteredByACLPeers( return authorizedPeers } -func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) { +func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Name). Msg("Finding direct peers") machines := Machines{} - if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", - machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ? AND registered", + machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") return Machines{}, err @@ -234,7 +221,7 @@ func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Name). - Msgf("Found direct machines: %s", machines.String()) + Msgf("Found peers: %s", machines.String()) return machines, nil } @@ -247,7 +234,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { // else use the classic namespace scope if h.aclPolicy != nil { var machines []Machine - machines, err = h.ListAllMachines() + machines, err = h.ListMachines() if err != nil { log.Error().Err(err).Msg("Error retrieving list of machines") @@ -255,7 +242,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { } peers = getFilteredByACLPeers(machines, h.aclRules, machine) } else { - peers, err = h.getDirectPeers(machine) + peers, err = h.ListPeers(machine) if err != nil { log.Error(). Caller(). diff --git a/machine_test.go b/machine_test.go index b1cd341..e9c91f8 100644 --- a/machine_test.go +++ b/machine_test.go @@ -118,7 +118,7 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { c.Assert(err, check.NotNil) } -func (s *Suite) TestGetDirectPeers(c *check.C) { +func (s *Suite) TestListPeers(c *check.C) { namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) @@ -149,7 +149,7 @@ func (s *Suite) TestGetDirectPeers(c *check.C) { _, err = machine0ByID.GetHostInfo() c.Assert(err, check.IsNil) - peersOfMachine0, err := app.getDirectPeers(machine0ByID) + peersOfMachine0, err := app.ListPeers(machine0ByID) c.Assert(err, check.IsNil) c.Assert(len(peersOfMachine0), check.Equals, 9) @@ -222,7 +222,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { _, err = testMachine.GetHostInfo() c.Assert(err, check.IsNil) - machines, err := app.ListAllMachines() + machines, err := app.ListMachines() c.Assert(err, check.IsNil) peersOfTestMachine := getFilteredByACLPeers(machines, app.aclRules, testMachine)