diff --git a/machine.go b/machine.go index aef310c..231c179 100644 --- a/machine.go +++ b/machine.go @@ -119,14 +119,21 @@ func (machine Machine) isExpired() bool { return time.Now().UTC().After(*machine.Expiry) } -// Our Pineapple fork of Headscale ignores namespaces when dealing with peers -// and instead passes ALL peers across all namespaces to each client. Access between clients -// is then enforced with ACL policies. -func (h *Headscale) getAllPeers(machine *Machine) (Machines, error) { +func containsAddresses(inputs []string, addrs MachineAddresses) bool { + for _, addr := range addrs.ToStringSlice() { + if containsString(inputs, addr) { + return true + } + } + return false +} + +// getFilteredByACLPeerss should return the list of peers authorized to be accessed from machine. +func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Name). - Msg("Finding all peers") + Msg("Finding peers filtered by ACLs") machines := Machines{} if err := h.db.Preload("Namespace").Where("machine_key <> ? AND registered", @@ -135,15 +142,50 @@ func (h *Headscale) getAllPeers(machine *Machine) (Machines, error) { return Machines{}, err } + mMachines := make(map[uint64]Machine) - sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) + // Aclfilter peers here. We are itering through machines in all namespaces and search through the computed aclRules + // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. + + // FIXME: On official control plane if a rule allow user A to talk to user B but NO rule allows user B to talk to + // userĀ A. The behaviour is the following + // + // On official tailscale control plane: + // on first `tailscale status`` on node A we can see node B. The `tailscale status` command on node B doesn't show node A + // We can successfully establish a communication from A to B. When it's done, if we run the `tailscale status` command + // on node B again we can now see node A. It's not possible to establish a communication from node B to node A. + // On this implementation of the feature + // on any `tailscale status` command on node A we can see node B. The `tailscale status` command on node B DOES show A. + // + // I couldn't find a way to not clutter the output of `tailscale status` with all nodes that we could be talking to. + // In order to do this we would need to be able to identify that node A want to talk to node B but that Node B doesn't know + // how to talk to node A and then add the peering resource. + + for _, m := range machines { + for _, rule := range h.aclRules { + var dst []string + for _, d := range rule.DstPorts { + dst = append(dst, d.IP) + } + if (containsAddresses(rule.SrcIPs, machine.IPAddresses) && (containsAddresses(dst, m.IPAddresses) || containsString(dst, "*"))) || + (containsAddresses(rule.SrcIPs, m.IPAddresses) && containsAddresses(dst, machine.IPAddresses)) { + mMachines[m.ID] = m + } + } + } + + var authorizedMachines Machines + for _, m := range mMachines { + authorizedMachines = append(authorizedMachines, m) + } + sort.Slice(authorizedMachines, func(i, j int) bool { return authorizedMachines[i].ID < authorizedMachines[j].ID }) log.Trace(). Caller(). Str("machine", machine.Name). - Msgf("Found all machines: %s", machines.String()) + Msgf("Found some machines: %s", machines.String()) - return machines, nil + return authorizedMachines, nil } func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) { @@ -233,47 +275,52 @@ func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) { } func (h *Headscale) getPeers(machine *Machine) (Machines, error) { - // direct, err := h.getDirectPeers(machine) - // if err != nil { - // log.Error(). - // Caller(). - // Err(err). - // Msg("Cannot fetch peers") + var peers Machines + var err error + // If ACLs rules are defined, filter visible host list with the ACLs + // else use the classic namespace scope + if h.aclPolicy != nil { + peers, err = h.getFilteredByACLPeers(machine) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot fetch peers") - // return Machines{}, err - // } + return Machines{}, err + } + } else { + direct, err := h.getDirectPeers(machine) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot fetch peers") - // shared, err := h.getShared(machine) - // if err != nil { - // log.Error(). - // Caller(). - // Err(err). - // Msg("Cannot fetch peers") + return Machines{}, err + } - // return Machines{}, err - // } + shared, err := h.getShared(machine) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot fetch peers") - // sharedTo, err := h.getSharedTo(machine) - // if err != nil { - // log.Error(). - // Caller(). - // Err(err). - // Msg("Cannot fetch peers") + return Machines{}, err + } - // return Machines{}, err - // } + sharedTo, err := h.getSharedTo(machine) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot fetch peers") - // peers := append(direct, shared...) - // peers = append(peers, sharedTo...) - - peers, err := h.getAllPeers(machine) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot fetch peers") - - return Machines{}, err + return Machines{}, err + } + peers = append(direct, shared...) + peers = append(peers, sharedTo...) } sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) diff --git a/machine_test.go b/machine_test.go index ff1dc91..f84603a 100644 --- a/machine_test.go +++ b/machine_test.go @@ -1,6 +1,7 @@ package headscale import ( + "fmt" "strconv" "time" @@ -154,6 +155,89 @@ func (s *Suite) TestGetDirectPeers(c *check.C) { c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10") } +func (s *Suite) TestGetACLFilteredPeers(c *check.C) { + type base struct { + namespace *Namespace + key *PreAuthKey + } + + var stor []base + + for _, name := range []string{"test", "admin"} { + namespace, err := app.CreateNamespace(name) + c.Assert(err, check.IsNil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + stor = append(stor, base{namespace, pak}) + + } + + _, err := app.GetMachineByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + machine := Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + IPAddress: fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)), + Name: "testmachine" + strconv.Itoa(index), + NamespaceID: stor[index%2].namespace.ID, + Registered: true, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(stor[index%2].key.ID), + } + app.db.Save(&machine) + } + + app.aclPolicy = &ACLPolicy{ + Groups: map[string][]string{ + "group:test": {"admin"}, + }, + Hosts: map[string]netaddr.IPPrefix{}, + TagOwners: map[string][]string{}, + ACLs: []ACL{ + {Action: "accept", Users: []string{"admin"}, Ports: []string{"*:*"}}, + {Action: "accept", Users: []string{"test"}, Ports: []string{"test:*"}}, + }, + Tests: []ACLTest{}, + } + + rules, err := app.generateACLRules() + c.Assert(err, check.IsNil) + app.aclRules = rules + + adminMachine, err := app.GetMachineByID(1) + c.Logf("Machine(%v), namespace: %v", adminMachine.Name, adminMachine.Namespace) + c.Assert(err, check.IsNil) + + testMachine, err := app.GetMachineByID(2) + c.Logf("Machine(%v), namespace: %v", testMachine.Name, testMachine.Namespace) + c.Assert(err, check.IsNil) + + _, err = testMachine.GetHostInfo() + c.Assert(err, check.IsNil) + + peersOfTestMachine, err := app.getFilteredByACLPeers(testMachine) + c.Assert(err, check.IsNil) + + peersOfAdminMachine, err := app.getFilteredByACLPeers(adminMachine) + c.Assert(err, check.IsNil) + + c.Log(peersOfTestMachine) + c.Assert(len(peersOfTestMachine), check.Equals, 4) + c.Assert(peersOfTestMachine[0].Name, check.Equals, "testmachine4") + c.Assert(peersOfTestMachine[1].Name, check.Equals, "testmachine6") + c.Assert(peersOfTestMachine[3].Name, check.Equals, "testmachine10") + + c.Log(peersOfAdminMachine) + c.Assert(len(peersOfAdminMachine), check.Equals, 9) + c.Assert(peersOfAdminMachine[0].Name, check.Equals, "testmachine2") + c.Assert(peersOfAdminMachine[2].Name, check.Equals, "testmachine4") + c.Assert(peersOfAdminMachine[5].Name, check.Equals, "testmachine7") +} + func (s *Suite) TestExpireMachine(c *check.C) { namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) diff --git a/utils.go b/utils.go index a6be8cc..794971a 100644 --- a/utils.go +++ b/utils.go @@ -212,6 +212,15 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { return ips, nil } +func containsString(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { for _, v := range ips { if v == ip {