diff --git a/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml b/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml new file mode 100644 index 0000000..70b7b94 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml @@ -0,0 +1,57 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestACLHostsInNetMapTable + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - uses: cachix/install-nix-action@v18 + if: ${{ env.ACT }} || steps.changed-files.outputs.any_changed == 'true' + + - name: Run general integration tests + if: steps.changed-files.outputs.any_changed == 'true' + run: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go test ./... \ + -tags ts2019 \ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestACLHostsInNetMapTable$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" diff --git a/acls.go b/acls.go index 86239d4..70295ea 100644 --- a/acls.go +++ b/acls.go @@ -133,6 +133,14 @@ func (h *Headscale) UpdateACLRules() error { log.Trace().Interface("ACL", rules).Msg("ACL rules generated") h.aclRules = rules + // Precompute a map of which sources can reach each destination, this is + // to provide quicker lookup when we calculate the peerlist for the map + // response to nodes. + aclPeerCacheMap := generateACLPeerCacheMap(rules) + h.aclPeerCacheMapRW.Lock() + h.aclPeerCacheMap = aclPeerCacheMap + h.aclPeerCacheMapRW.Unlock() + if featureEnableSSH() { sshRules, err := h.generateSSHRules() if err != nil { @@ -150,6 +158,30 @@ func (h *Headscale) UpdateACLRules() error { return nil } +// generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map +// of which Sources ("*" and IPs) can access destinations. This is to speed up the +// process of generating MapResponses when deciding which Peers to inform nodes about. +func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]struct{} { + aclCachePeerMap := make(map[string]map[string]struct{}) + for _, rule := range rules { + for _, srcIP := range rule.SrcIPs { + if data, ok := aclCachePeerMap[srcIP]; ok { + for _, dstPort := range rule.DstPorts { + data[dstPort.IP] = struct{}{} + } + } else { + dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) + for _, dstPort := range rule.DstPorts { + dstPortsMap[dstPort.IP] = struct{}{} + } + aclCachePeerMap[srcIP] = dstPortsMap + } + } + } + + return aclCachePeerMap +} + func generateACLRules( machines []Machine, aclPolicy ACLPolicy, diff --git a/app.go b/app.go index 219f64f..26a8e23 100644 --- a/app.go +++ b/app.go @@ -84,9 +84,11 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *DERPServer - aclPolicy *ACLPolicy - aclRules []tailcfg.FilterRule - sshPolicy *tailcfg.SSHPolicy + aclPolicy *ACLPolicy + aclRules []tailcfg.FilterRule + aclPeerCacheMapRW sync.RWMutex + aclPeerCacheMap map[string]map[string]struct{} + sshPolicy *tailcfg.SSHPolicy lastStateChange *xsync.MapOf[string, time.Time] diff --git a/integration/acl_test.go b/integration/acl_test.go new file mode 100644 index 0000000..7e3eaa5 --- /dev/null +++ b/integration/acl_test.go @@ -0,0 +1,181 @@ +package integration + +import ( + "testing" + + "github.com/juanfont/headscale" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" +) + +// This tests a different ACL mechanism, if a host _cannot_ connect +// to another node at all based on ACL, it should just not be part +// of the NetMap sent to the host. This is slightly different than +// the other tests as we can just check if the hosts are present +// or not. +func TestACLHostsInNetMapTable(t *testing.T) { + IntegrationSkip(t) + + // NOTE: All want cases currently checks the + // total count of expected peers, this would + // typically be the client count of the users + // they can access minus one (them self). + tests := map[string]struct { + users map[string]int + policy headscale.ACLPolicy + want map[string]int + }{ + // Test that when we have no ACL, each client netmap has + // the amount of peers of the total amount of clients + "base-acls": { + users: map[string]int{ + "user1": 2, + "user2": 2, + }, + policy: headscale.ACLPolicy{ + ACLs: []headscale.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + }, want: map[string]int{ + "user1": 3, // ns1 + ns2 + "user2": 3, // ns2 + ns1 + }, + }, + // Test that when we have two users, which cannot see + // eachother, each node has only the number of pairs from + // their own user. + "two-isolated-users": { + users: map[string]int{ + "user1": 2, + "user2": 2, + }, + policy: headscale.ACLPolicy{ + ACLs: []headscale.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"user1:*"}, + }, + { + Action: "accept", + Sources: []string{"user2"}, + Destinations: []string{"user2:*"}, + }, + }, + }, want: map[string]int{ + "user1": 1, + "user2": 1, + }, + }, + // Test that when we have two users, with ACLs and they + // are restricted to a single port, nodes are still present + // in the netmap. + "two-restricted-present-in-netmap": { + users: map[string]int{ + "user1": 2, + "user2": 2, + }, + policy: headscale.ACLPolicy{ + ACLs: []headscale.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"user1:22"}, + }, + { + Action: "accept", + Sources: []string{"user2"}, + Destinations: []string{"user2:22"}, + }, + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"user2:22"}, + }, + { + Action: "accept", + Sources: []string{"user2"}, + Destinations: []string{"user1:22"}, + }, + }, + }, want: map[string]int{ + "user1": 3, + "user2": 3, + }, + }, + // Test that when we have two users, that are isolated, + // but one can see the others, we have the appropriate number + // of peers. This will still result in all the peers as we + // need them present on the other side for the "return path". + "two-ns-one-isolated": { + users: map[string]int{ + "user1": 2, + "user2": 2, + }, + policy: headscale.ACLPolicy{ + ACLs: []headscale.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"user1:*"}, + }, + { + Action: "accept", + Sources: []string{"user2"}, + Destinations: []string{"user2:*"}, + }, + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"user2:*"}, + }, + }, + }, want: map[string]int{ + "user1": 3, // ns1 + ns2 + "user2": 3, // ns1 + ns2 (return path) + }, + }, + } + + for name, testCase := range tests { + t.Run(name, func(t *testing.T) { + scenario, err := NewScenario() + assert.NoError(t, err) + + spec := testCase.users + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithACLPolicy(&testCase.policy), + // hsic.WithTestName(fmt.Sprintf("aclinnetmap%s", name)), + ) + assert.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + assert.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + assert.NoError(t, err) + + // allHostnames, err := scenario.ListTailscaleClientsFQDNs() + // assert.NoError(t, err) + + for _, client := range allClients { + status, err := client.Status() + assert.NoError(t, err) + + user := status.User[status.Self.UserID].LoginName + + assert.Equal(t, (testCase.want[user]), len(status.Peer)) + } + + err = scenario.Shutdown() + assert.NoError(t, err) + }) + } +} diff --git a/machine.go b/machine.go index c31e384..fd6e2ed 100644 --- a/machine.go +++ b/machine.go @@ -8,6 +8,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -160,32 +161,18 @@ func (machine *Machine) isEphemeral() bool { return machine.AuthKey != nil && machine.AuthKey.Ephemeral } -func containsAddresses(inputs []string, addrs []string) bool { - for _, addr := range addrs { - if containsStr(inputs, addr) { - return true - } - } - - return false +// filterMachinesByACL wrapper function to not have devs pass around locks and maps +// related to the application outside of tests. +func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines { + return filterMachinesByACL(currentMachine, peers, &h.aclPeerCacheMapRW, h.aclPeerCacheMap) } -// matchSourceAndDestinationWithRule. -func matchSourceAndDestinationWithRule( - ruleSources []string, - ruleDestinations []string, - source []string, - destination []string, -) bool { - return containsAddresses(ruleSources, source) && - containsAddresses(ruleDestinations, destination) -} - -// getFilteredByACLPeerss should return the list of peers authorized to be accessed from machine. -func getFilteredByACLPeers( - machines []Machine, - rules []tailcfg.FilterRule, +// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. +func filterMachinesByACL( machine *Machine, + machines []Machine, + lock *sync.RWMutex, + aclPeerCacheMap map[string]map[string]struct{}, ) Machines { log.Trace(). Caller(). @@ -196,57 +183,80 @@ func getFilteredByACLPeers( // Aclfilter peers here. We are itering through machines in all users and search through the computed aclRules // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. machineIPs := machine.IPAddresses.ToStringSlice() + + // TODO(kradalby): Remove this lock, I suspect its not a good idea, and might not be necessary, + // we only set this at startup atm (reading ACLs) and it might become a bottleneck. + lock.RLock() + for _, peer := range machines { if peer.ID == machine.ID { continue } - for _, rule := range rules { - var dst []string - for _, d := range rule.DstPorts { - dst = append(dst, d.IP) - } - peerIPs := peer.IPAddresses.ToStringSlice() - if matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - machineIPs, - peerIPs, - ) || // match source and destination - matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - peerIPs, - machineIPs, - ) || // match return path - matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - machineIPs, - []string{"*"}, - ) || // match source and all destination - matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - []string{"*"}, - []string{"*"}, - ) || // match source and all destination - matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - []string{"*"}, - peerIPs, - ) || // match source and all destination - matchSourceAndDestinationWithRule( - rule.SrcIPs, - dst, - []string{"*"}, - machineIPs, - ) { // match all sources and source + peerIPs := peer.IPAddresses.ToStringSlice() + + if dstMap, ok := aclPeerCacheMap["*"]; ok { + // match source and all destination + if _, dstOk := dstMap["*"]; dstOk { peers[peer.ID] = peer + + continue + } + + // match source and all destination + for _, peerIP := range peerIPs { + if _, dstOk := dstMap[peerIP]; dstOk { + peers[peer.ID] = peer + + continue + } + } + + // match all sources and source + for _, machineIP := range machineIPs { + if _, dstOk := dstMap[machineIP]; dstOk { + peers[peer.ID] = peer + + continue + } + } + } + + for _, machineIP := range machineIPs { + if dstMap, ok := aclPeerCacheMap[machineIP]; ok { + // match source and all destination + if _, dstOk := dstMap["*"]; dstOk { + peers[peer.ID] = peer + + continue + } + + // match source and destination + for _, peerIP := range peerIPs { + if _, dstOk := dstMap[peerIP]; dstOk { + peers[peer.ID] = peer + + continue + } + } + } + } + + for _, peerIP := range peerIPs { + if dstMap, ok := aclPeerCacheMap[peerIP]; ok { + // match return path + for _, machineIP := range machineIPs { + if _, dstOk := dstMap[machineIP]; dstOk { + peers[peer.ID] = peer + + continue + } + } } } } + lock.RUnlock() + authorizedPeers := make([]Machine, 0, len(peers)) for _, m := range peers { authorizedPeers = append(authorizedPeers, m) @@ -302,7 +312,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { return Machines{}, err } - peers = getFilteredByACLPeers(machines, h.aclRules, machine) + peers = h.filterMachinesByACL(machine, machines) } else { peers, err = h.ListPeers(machine) if err != nil { diff --git a/machine_test.go b/machine_test.go index 8e35cbf..86eb191 100644 --- a/machine_test.go +++ b/machine_test.go @@ -6,6 +6,7 @@ import ( "reflect" "regexp" "strconv" + "sync" "testing" "time" @@ -277,8 +278,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { machines, err := app.ListMachines() c.Assert(err, check.IsNil) - peersOfTestMachine := getFilteredByACLPeers(machines, app.aclRules, testMachine) - peersOfAdminMachine := getFilteredByACLPeers(machines, app.aclRules, adminMachine) + peersOfTestMachine := app.filterMachinesByACL(testMachine, machines) + peersOfAdminMachine := app.filterMachinesByACL(adminMachine, machines) c.Log(peersOfTestMachine) c.Assert(len(peersOfTestMachine), check.Equals, 4) @@ -950,15 +951,19 @@ func Test_getFilteredByACLPeers(t *testing.T) { want: Machines{}, }, } + var lock sync.RWMutex for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := getFilteredByACLPeers( - tt.args.machines, - tt.args.rules, + aclRulesMap := generateACLPeerCacheMap(tt.args.rules) + + got := filterMachinesByACL( tt.args.machine, + tt.args.machines, + &lock, + aclRulesMap, ) if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getFilteredByACLPeers() = %v, want %v", got, tt.want) + t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want) } }) } diff --git a/utils.go b/utils.go index 5d7f487..8bdb2b3 100644 --- a/utils.go +++ b/utils.go @@ -269,16 +269,6 @@ func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { return result, nil } -func containsStr(ts []string, t string) bool { - for _, v := range ts { - if v == t { - return true - } - } - - return false -} - func contains[T string | netip.Prefix](ts []T, t T) bool { for _, v := range ts { if reflect.DeepEqual(v, t) {