diff --git a/acls.go b/acls.go index 9233cb4..6e23e0a 100644 --- a/acls.go +++ b/acls.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog/log" "github.com/tailscale/hujson" + "go4.org/netipx" "gopkg.in/yaml.v3" "tailscale.com/envknob" "tailscale.com/tailcfg" @@ -165,16 +166,22 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s aclCachePeerMap := make(map[string]map[string]struct{}) for _, rule := range rules { for _, srcIP := range rule.SrcIPs { - if data, ok := aclCachePeerMap[srcIP]; ok { - for _, dstPort := range rule.DstPorts { - data[dstPort.IP] = struct{}{} + for _, ip := range expandACLPeerAddr(srcIP) { + if data, ok := aclCachePeerMap[ip]; ok { + for _, dstPort := range rule.DstPorts { + for _, dstIP := range expandACLPeerAddr(dstPort.IP) { + data[dstIP] = struct{}{} + } + } + } else { + dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) + for _, dstPort := range rule.DstPorts { + for _, dstIP := range expandACLPeerAddr(dstPort.IP) { + dstPortsMap[dstIP] = struct{}{} + } + } + aclCachePeerMap[ip] = dstPortsMap } - } else { - dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) - for _, dstPort := range rule.DstPorts { - dstPortsMap[dstPort.IP] = struct{}{} - } - aclCachePeerMap[srcIP] = dstPortsMap } } } @@ -184,6 +191,41 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s return aclCachePeerMap } +// expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into +// something our cache logic can look up, which is "*" or single IP addresses. +// This is probably quite inefficient, but it is a result of +// "make it work, then make it fast", and a lot of the ACL stuff does not +// work, but people have tried to make it fast. +func expandACLPeerAddr(srcIP string) []string { + if ip, err := netip.ParseAddr(srcIP); err == nil { + return []string{ip.String()} + } + + if cidr, err := netip.ParsePrefix(srcIP); err == nil { + addrs := []string{} + + ipRange := netipx.RangeOfPrefix(cidr) + + from := ipRange.From() + too := ipRange.To() + + if from == too { + return []string{from.String()} + } + + for from != too { + addrs = append(addrs, from.String()) + + from = from.Next() + } + + return addrs + } + + // probably "*" or other string based "IP" + return []string{srcIP} +} + func generateACLRules( machines []Machine, aclPolicy ACLPolicy, diff --git a/acls_test.go b/acls_test.go index 8bd8585..c4d619d 100644 --- a/acls_test.go +++ b/acls_test.go @@ -1556,3 +1556,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { }) } } + +func Test_expandACLPeerAddr(t *testing.T) { + type args struct { + srcIP string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "asterix", + args: args{ + srcIP: "*", + }, + want: []string{"*"}, + }, + { + name: "ip", + args: args{ + srcIP: "10.0.0.1", + }, + want: []string{"10.0.0.1"}, + }, + { + name: "ip/32", + args: args{ + srcIP: "10.0.0.1/32", + }, + want: []string{"10.0.0.1"}, + }, + { + name: "ip/30", + args: args{ + srcIP: "10.0.0.1/30", + }, + want: []string{ + "10.0.0.0", + "10.0.0.1", + "10.0.0.2", + }, + }, + { + name: "ip/28", + args: args{ + srcIP: "192.168.0.128/28", + }, + want: []string{ + "192.168.0.128", "192.168.0.129", "192.168.0.130", + "192.168.0.131", "192.168.0.132", "192.168.0.133", + "192.168.0.134", "192.168.0.135", "192.168.0.136", + "192.168.0.137", "192.168.0.138", "192.168.0.139", + "192.168.0.140", "192.168.0.141", "192.168.0.142", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/integration/acl_test.go b/integration/acl_test.go index d704324..42f9b94 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -2,6 +2,7 @@ package integration import ( "fmt" + "net/netip" "strings" "testing" @@ -439,3 +440,214 @@ func TestACLAllowStarDst(t *testing.T) { err = scenario.Shutdown() assert.NoError(t, err) } + +// This test aims to cover cases where individual hosts are allowed and denied +// access based on their assigned hostname +// https://github.com/juanfont/headscale/issues/941 + +// ACL = [{ +// "DstPorts": [{ +// "Bits": null, +// "IP": "100.64.0.3/32", +// "Ports": { +// "First": 0, +// "Last": 65535 +// } +// }], +// "SrcIPs": ["*"] +// }, { +// +// "DstPorts": [{ +// "Bits": null, +// "IP": "100.64.0.2/32", +// "Ports": { +// "First": 0, +// "Last": 65535 +// } +// }], +// "SrcIPs": ["100.64.0.1/32"] +// }] +// +// ACL Cache Map= { +// "*": { +// "100.64.0.3/32": {} +// }, +// "100.64.0.1/32": { +// "100.64.0.2/32": {} +// } +// } +func TestACLNamedHostsCanReach(t *testing.T) { + IntegrationSkip(t) + + scenario := aclScenario(t, + headscale.ACLPolicy{ + Hosts: headscale.Hosts{ + "test1": netip.MustParsePrefix("100.64.0.1/32"), + "test2": netip.MustParsePrefix("100.64.0.2/32"), + "test3": netip.MustParsePrefix("100.64.0.3/32"), + }, + ACLs: []headscale.ACL{ + // Everyone can curl test3 + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"test3:*"}, + }, + // test1 can curl test2 + { + Action: "accept", + Sources: []string{"test1"}, + Destinations: []string{"test2:*"}, + }, + }, + }, + ) + + // Since user/users dont matter here, we basically expect that some clients + // will be assigned these ips and that we can pick them up for our own use. + test1ip := netip.MustParseAddr("100.64.0.1") + test1, err := scenario.FindTailscaleClientByIP(test1ip) + assert.NoError(t, err) + + test1fqdn, err := test1.FQDN() + assert.NoError(t, err) + test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) + test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) + + test2ip := netip.MustParseAddr("100.64.0.2") + test2, err := scenario.FindTailscaleClientByIP(test2ip) + assert.NoError(t, err) + + test2fqdn, err := test2.FQDN() + assert.NoError(t, err) + test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) + test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) + + test3ip := netip.MustParseAddr("100.64.0.3") + test3, err := scenario.FindTailscaleClientByIP(test3ip) + assert.NoError(t, err) + + test3fqdn, err := test3.FQDN() + assert.NoError(t, err) + test3ipURL := fmt.Sprintf("http://%s/etc/hostname", test3ip.String()) + test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) + + // test1 can query test3 + result, err := test1.Curl(test3ipURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + result, err = test1.Curl(test3fqdnURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + // test2 can query test3 + result, err = test2.Curl(test3ipURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + result, err = test2.Curl(test3fqdnURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + // test3 cannot query test1 + result, err = test3.Curl(test1ipURL) + assert.Empty(t, result) + assert.Error(t, err) + + result, err = test3.Curl(test1fqdnURL) + assert.Empty(t, result) + assert.Error(t, err) + + // test3 cannot query test2 + result, err = test3.Curl(test2ipURL) + assert.Empty(t, result) + assert.Error(t, err) + + result, err = test3.Curl(test2fqdnURL) + assert.Empty(t, result) + assert.Error(t, err) + + // test1 can query test2 + result, err = test1.Curl(test2ipURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + result, err = test1.Curl(test2fqdnURL) + assert.Len(t, result, 13) + assert.NoError(t, err) + + // test2 cannot query test1 + result, err = test2.Curl(test1ipURL) + assert.Empty(t, result) + assert.Error(t, err) + + result, err = test2.Curl(test1fqdnURL) + assert.Empty(t, result) + assert.Error(t, err) + + err = scenario.Shutdown() + assert.NoError(t, err) +} + +// TestACLNamedHostsCanReachBySubnet is the same as +// TestACLNamedHostsCanReach, but it tests if we expand a +// full CIDR correctly. All routes should work. +func TestACLNamedHostsCanReachBySubnet(t *testing.T) { + IntegrationSkip(t) + + scenario := aclScenario(t, + headscale.ACLPolicy{ + Hosts: headscale.Hosts{ + "all": netip.MustParsePrefix("100.64.0.0/24"), + }, + ACLs: []headscale.ACL{ + // Everyone can curl test3 + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"all:*"}, + }, + }, + }, + ) + + user1Clients, err := scenario.ListTailscaleClients("user1") + assert.NoError(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + assert.NoError(t, err) + + // Test that user1 can visit all user2 + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + assert.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + result, err := client.Curl(url) + assert.Len(t, result, 13) + assert.NoError(t, err) + } + } + + // Test that user2 can visit all user1 + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + assert.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + result, err := client.Curl(url) + assert.Len(t, result, 13) + assert.NoError(t, err) + } + } + + err = scenario.Shutdown() + assert.NoError(t, err) +} diff --git a/machine.go b/machine.go index 146bfcf..71217ab 100644 --- a/machine.go +++ b/machine.go @@ -170,13 +170,14 @@ func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. func filterMachinesByACL( machine *Machine, - machines []Machine, + machines Machines, lock *sync.RWMutex, aclPeerCacheMap map[string]map[string]struct{}, ) Machines { log.Trace(). Caller(). - Str("machine", machine.Hostname). + Str("self", machine.Hostname). + Str("input", machines.String()). Msg("Finding peers filtered by ACLs") peers := make(map[uint64]Machine) @@ -263,7 +264,7 @@ func filterMachinesByACL( lock.RUnlock() - authorizedPeers := make([]Machine, 0, len(peers)) + authorizedPeers := make(Machines, 0, len(peers)) for _, m := range peers { authorizedPeers = append(authorizedPeers, m) } @@ -274,8 +275,9 @@ func filterMachinesByACL( log.Trace(). Caller(). - Str("machine", machine.Hostname). - Msgf("Found some machines: %v", machines) + Str("self", machine.Hostname). + Str("peers", authorizedPeers.String()). + Msg("Authorized peers") return authorizedPeers } @@ -335,8 +337,9 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", machine.Hostname). - Msgf("Found total peers: %s", peers.String()) + Str("self", machine.Hostname). + Str("peers", peers.String()). + Msg("Peers returned to caller") return peers, nil }