From 889d5a1b2982e37021546ae46a9df5c135455038 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 26 Apr 2023 17:27:51 +0200 Subject: [PATCH] testing without that horrible filtercode Signed-off-by: Kristoffer Dalby --- acls.go | 263 ++++++++++++++++++++++++++++++++----------- acls_test.go | 284 ++++++++++++++++++++++++----------------------- go.mod | 6 +- go.sum | 5 +- machine.go | 290 +++++++++++++++++++++++++++--------------------- machine_test.go | 7 +- 6 files changed, 510 insertions(+), 345 deletions(-) diff --git a/acls.go b/acls.go index 73f437b..3e146cd 100644 --- a/acls.go +++ b/acls.go @@ -132,16 +132,17 @@ func (h *Headscale) UpdateACLRules() error { if err != nil { return err } + log.Trace().Interface("ACL", rules).Msg("ACL rules generated") h.aclRules = rules // Precompute a map of which sources can reach each destination, this is // to provide quicker lookup when we calculate the peerlist for the map // response to nodes. - aclPeerCacheMap := generateACLPeerCacheMap(rules) - h.aclPeerCacheMapRW.Lock() - h.aclPeerCacheMap = aclPeerCacheMap - h.aclPeerCacheMapRW.Unlock() + // aclPeerCacheMap := generateACLPeerCacheMap(rules) + // h.aclPeerCacheMapRW.Lock() + // h.aclPeerCacheMap = aclPeerCacheMap + // h.aclPeerCacheMapRW.Unlock() if featureEnableSSH() { sshRules, err := h.generateSSHRules() @@ -160,69 +161,69 @@ func (h *Headscale) UpdateACLRules() error { return nil } -// generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map -// of which Sources ("*" and IPs) can access destinations. This is to speed up the -// process of generating MapResponses when deciding which Peers to inform nodes about. -func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string { - aclCachePeerMap := make(map[string][]string) - for _, rule := range rules { - for _, srcIP := range rule.SrcIPs { - for _, ip := range expandACLPeerAddr(srcIP) { - if data, ok := aclCachePeerMap[ip]; ok { - for _, dstPort := range rule.DstPorts { - data = append(data, dstPort.IP) - } - aclCachePeerMap[ip] = data - } else { - dstPortsMap := make([]string, 0) - for _, dstPort := range rule.DstPorts { - dstPortsMap = append(dstPortsMap, dstPort.IP) - } - aclCachePeerMap[ip] = dstPortsMap - } - } - } - } - - log.Trace().Interface("ACL Cache Map", aclCachePeerMap).Msg("ACL Peer Cache Map generated") - - return aclCachePeerMap -} - -// expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into -// something our cache logic can look up, which is "*" or single IP addresses. -// This is probably quite inefficient, but it is a result of -// "make it work, then make it fast", and a lot of the ACL stuff does not -// work, but people have tried to make it fast. -func expandACLPeerAddr(srcIP string) []string { - if ip, err := netip.ParseAddr(srcIP); err == nil { - return []string{ip.String()} - } - - if cidr, err := netip.ParsePrefix(srcIP); err == nil { - addrs := []string{} - - ipRange := netipx.RangeOfPrefix(cidr) - - from := ipRange.From() - too := ipRange.To() - - if from == too { - return []string{from.String()} - } - - for from != too && from.Less(too) { - addrs = append(addrs, from.String()) - from = from.Next() - } - addrs = append(addrs, too.String()) // Add the last IP address in the range - - return addrs - } - - // probably "*" or other string based "IP" - return []string{srcIP} -} +// // generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map +// // of which Sources ("*" and IPs) can access destinations. This is to speed up the +// // process of generating MapResponses when deciding which Peers to inform nodes about. +// func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string { +// aclCachePeerMap := make(map[string][]string) +// for _, rule := range rules { +// for _, srcIP := range rule.SrcIPs { +// for _, ip := range expandACLPeerAddr(srcIP) { +// if data, ok := aclCachePeerMap[ip]; ok { +// for _, dstPort := range rule.DstPorts { +// data = append(data, dstPort.IP) +// } +// aclCachePeerMap[ip] = data +// } else { +// dstPortsMap := make([]string, 0) +// for _, dstPort := range rule.DstPorts { +// dstPortsMap = append(dstPortsMap, dstPort.IP) +// } +// aclCachePeerMap[ip] = dstPortsMap +// } +// } +// } +// } +// +// log.Trace().Interface("ACL Cache Map", aclCachePeerMap).Msg("ACL Peer Cache Map generated") +// +// return aclCachePeerMap +// } +// +// // expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into +// // something our cache logic can look up, which is "*" or single IP addresses. +// // This is probably quite inefficient, but it is a result of +// // "make it work, then make it fast", and a lot of the ACL stuff does not +// // work, but people have tried to make it fast. +// func expandACLPeerAddr(srcIP string) []string { +// if ip, err := netip.ParseAddr(srcIP); err == nil { +// return []string{ip.String()} +// } +// +// if cidr, err := netip.ParsePrefix(srcIP); err == nil { +// addrs := []string{} +// +// ipRange := netipx.RangeOfPrefix(cidr) +// +// from := ipRange.From() +// too := ipRange.To() +// +// if from == too { +// return []string{from.String()} +// } +// +// for from != too && from.Less(too) { +// addrs = append(addrs, from.String()) +// from = from.Next() +// } +// addrs = append(addrs, too.String()) // Add the last IP address in the range +// +// return addrs +// } +// +// // probably "*" or other string based "IP" +// return []string{srcIP} +// } // generateFilterRules takes a set of machines and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. @@ -879,3 +880,131 @@ func (pol *ACLPolicy) getIPsFromIPPrefix( return lo.Uniq(val), nil } + +// This is borrowed from +// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 +var ( + zeroIP4 = netip.AddrFrom4([4]byte{}) + zeroIP6 = netip.AddrFrom16([16]byte{}) +) + +// parseIPSet parses arg as one: +// +// - an IP address (IPv4 or IPv6) +// - the string "*" to match everything (both IPv4 & IPv6) +// - a CIDR (e.g. "192.168.0.0/16") +// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") +// +// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP +// address (without a slash) treated as a CIDR of *bits length. +// +// TODO(bradfitz): make this return an IPSet and plumb that all +// around, and ultimately use a new version of IPSet.ContainsFunc like +// Contains16Func that works in [16]byte address, so we we can match +// at runtime without allocating? +func parseIPSet(arg string, bits *int) ([]netip.Prefix, error) { + if arg == "*" { + // User explicitly requested wildcard. + return []netip.Prefix{ + netip.PrefixFrom(zeroIP4, 0), + netip.PrefixFrom(zeroIP6, 0), + }, nil + } + if strings.Contains(arg, "/") { + pfx, err := netip.ParsePrefix(arg) + if err != nil { + return nil, err + } + if pfx != pfx.Masked() { + return nil, fmt.Errorf("%v contains non-network bits set", pfx) + } + return []netip.Prefix{pfx}, nil + } + if strings.Count(arg, "-") == 1 { + ip1s, ip2s, _ := strings.Cut(arg, "-") + ip1, err := netip.ParseAddr(ip1s) + if err != nil { + return nil, err + } + ip2, err := netip.ParseAddr(ip2s) + if err != nil { + return nil, err + } + r := netipx.IPRangeFrom(ip1, ip2) + if !r.Valid() { + return nil, fmt.Errorf("invalid IP range %q", arg) + } + return r.Prefixes(), nil + } + ip, err := netip.ParseAddr(arg) + if err != nil { + return nil, fmt.Errorf("invalid IP address %q", arg) + } + bits8 := uint8(ip.BitLen()) + if bits != nil { + if *bits < 0 || *bits > int(bits8) { + return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) + } + bits8 = uint8(*bits) + } + return []netip.Prefix{netip.PrefixFrom(ip, int(bits8))}, nil +} + +func ipInPrefixList(ip netip.Addr, netlist []netip.Prefix) bool { + for _, net := range netlist { + if net.Contains(ip) { + return true + } + } + return false +} + +type Match struct { + Srcs []netip.Prefix + Dests []netip.Prefix +} + +func MatchFromFilterRule(rule tailcfg.FilterRule) Match { + match := Match{ + Srcs: []netip.Prefix{}, + Dests: []netip.Prefix{}, + } + + for _, srcIP := range rule.SrcIPs { + prefix, _ := parseIPSet(srcIP, nil) + + match.Srcs = append(match.Srcs, prefix...) + } + + for _, dest := range rule.DstPorts { + prefix, _ := parseIPSet(dest.IP, nil) + + match.Dests = append(match.Dests, prefix...) + } + + return match +} + +func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { + for _, prefix := range m.Srcs { + for _, ip := range ips { + if prefix.Contains(ip) { + return true + } + } + } + + return false +} + +func (m *Match) DestsContainsIP(ips []netip.Addr) bool { + for _, prefix := range m.Dests { + for _, ip := range ips { + if prefix.Contains(ip) { + return true + } + } + } + + return false +} diff --git a/acls_test.go b/acls_test.go index f96ac17..fef92e4 100644 --- a/acls_test.go +++ b/acls_test.go @@ -1661,140 +1661,140 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { } } -func Test_expandACLPeerAddr(t *testing.T) { - type args struct { - srcIP string - } - tests := []struct { - name string - args args - want []string - }{ - { - name: "asterix", - args: args{ - srcIP: "*", - }, - want: []string{"*"}, - }, - { - name: "ip", - args: args{ - srcIP: "10.0.0.1", - }, - want: []string{"10.0.0.1"}, - }, - { - name: "ip/32", - args: args{ - srcIP: "10.0.0.1/32", - }, - want: []string{"10.0.0.1"}, - }, - { - name: "ip/30", - args: args{ - srcIP: "10.0.0.1/30", - }, - want: []string{ - "10.0.0.0", - "10.0.0.1", - "10.0.0.2", - "10.0.0.3", - }, - }, - { - name: "ip/28", - args: args{ - srcIP: "192.168.0.128/28", - }, - want: []string{ - "192.168.0.128", "192.168.0.129", "192.168.0.130", - "192.168.0.131", "192.168.0.132", "192.168.0.133", - "192.168.0.134", "192.168.0.135", "192.168.0.136", - "192.168.0.137", "192.168.0.138", "192.168.0.139", - "192.168.0.140", "192.168.0.141", "192.168.0.142", - "192.168.0.143", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) - } - }) - } -} +// func Test_expandACLPeerAddr(t *testing.T) { +// type args struct { +// srcIP string +// } +// tests := []struct { +// name string +// args args +// want []string +// }{ +// { +// name: "asterix", +// args: args{ +// srcIP: "*", +// }, +// want: []string{"*"}, +// }, +// { +// name: "ip", +// args: args{ +// srcIP: "10.0.0.1", +// }, +// want: []string{"10.0.0.1"}, +// }, +// { +// name: "ip/32", +// args: args{ +// srcIP: "10.0.0.1/32", +// }, +// want: []string{"10.0.0.1"}, +// }, +// { +// name: "ip/30", +// args: args{ +// srcIP: "10.0.0.1/30", +// }, +// want: []string{ +// "10.0.0.0", +// "10.0.0.1", +// "10.0.0.2", +// "10.0.0.3", +// }, +// }, +// { +// name: "ip/28", +// args: args{ +// srcIP: "192.168.0.128/28", +// }, +// want: []string{ +// "192.168.0.128", "192.168.0.129", "192.168.0.130", +// "192.168.0.131", "192.168.0.132", "192.168.0.133", +// "192.168.0.134", "192.168.0.135", "192.168.0.136", +// "192.168.0.137", "192.168.0.138", "192.168.0.139", +// "192.168.0.140", "192.168.0.141", "192.168.0.142", +// "192.168.0.143", +// }, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { +// t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) +// } +// }) +// } +// } -func Test_expandACLPeerAddrV6(t *testing.T) { - type args struct { - srcIP string - } - tests := []struct { - name string - args args - want []string - }{ - { - name: "asterix", - args: args{ - srcIP: "*", - }, - want: []string{"*"}, - }, - { - name: "ipfull", - args: args{ - srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:3166", - }, - want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:3166"}, - }, - { - name: "ipzerocompression", - args: args{ - srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c::", - }, - want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:0"}, - }, - { - name: "ip/128", - args: args{ - srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:3166/128", - }, - want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:3166"}, - }, - { - name: "ip/127", - args: args{ - srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:0000/127", - }, - want: []string{ - "fd7a:115c:a1e0:ab12:4943:cd96:624c:0", - "fd7a:115c:a1e0:ab12:4943:cd96:624c:1", - }, - }, - { - name: "ip/126", - args: args{ - srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:0000/126", - }, - want: []string{ - "fd7a:115c:a1e0:ab12:4943:cd96:624c:0", - "fd7a:115c:a1e0:ab12:4943:cd96:624c:1", - "fd7a:115c:a1e0:ab12:4943:cd96:624c:2", - "fd7a:115c:a1e0:ab12:4943:cd96:624c:3", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) - } - }) - } -} +// func Test_expandACLPeerAddrV6(t *testing.T) { +// type args struct { +// srcIP string +// } +// tests := []struct { +// name string +// args args +// want []string +// }{ +// { +// name: "asterix", +// args: args{ +// srcIP: "*", +// }, +// want: []string{"*"}, +// }, +// { +// name: "ipfull", +// args: args{ +// srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:3166", +// }, +// want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:3166"}, +// }, +// { +// name: "ipzerocompression", +// args: args{ +// srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c::", +// }, +// want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:0"}, +// }, +// { +// name: "ip/128", +// args: args{ +// srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:3166/128", +// }, +// want: []string{"fd7a:115c:a1e0:ab12:4943:cd96:624c:3166"}, +// }, +// { +// name: "ip/127", +// args: args{ +// srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:0000/127", +// }, +// want: []string{ +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:0", +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:1", +// }, +// }, +// { +// name: "ip/126", +// args: args{ +// srcIP: "fd7a:115c:a1e0:ab12:4943:cd96:624c:0000/126", +// }, +// want: []string{ +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:0", +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:1", +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:2", +// "fd7a:115c:a1e0:ab12:4943:cd96:624c:3", +// }, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { +// t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) +// } +// }) +// } +// } func TestACLPolicy_generateFilterRules(t *testing.T) { type field struct { @@ -1819,7 +1819,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { wantErr: false, }, { - name: "simple group", + name: "allow-all", field: field{ pol: ACLPolicy{ ACLs: []ACL{ @@ -1852,7 +1852,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { wantErr: false, }, { - name: "simple host by ipv4 single dual stack", + name: "host1-can-reach-host2", field: field{ pol: ACLPolicy{ ACLs: []ACL{ @@ -1868,14 +1868,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { machines: []Machine{ { IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), }, User: User{Name: "mickael"}, }, { IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, User: User{Name: "mickael"}, @@ -1883,10 +1883,9 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, stripEmailDomain: true, }, - // [{"SrcIPs":["100.64.0.1"],"DstPorts":[{"IP":"100.64.0.2","Bits":null,"Ports":{"First":0,"Last":65535}}]}] want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1"}, + SrcIPs: []string{"100.64.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2221"}, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.2", @@ -1895,6 +1894,13 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { Last: 65535, }, }, + { + IP: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, }, }, }, diff --git a/go.mod b/go.mod index b2c1295..e26e61a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( github.com/AlecAivazis/survey/v2 v2.3.6 - github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 github.com/cenkalti/backoff/v4 v4.2.0 github.com/coreos/go-oidc/v3 v3.5.0 github.com/davecgh/go-spew v1.1.1 @@ -12,6 +11,7 @@ require ( github.com/efekarakus/termcolor v1.0.1 github.com/glebarez/sqlite v1.7.0 github.com/gofrs/uuid/v5 v5.0.0 + github.com/google/go-cmp v0.5.9 github.com/gorilla/mux v1.8.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.2 @@ -20,6 +20,7 @@ require ( github.com/ory/dockertest/v3 v3.9.1 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/philip-bui/grpc-zerolog v1.0.1 + github.com/pkg/profile v1.7.0 github.com/prometheus/client_golang v1.14.0 github.com/prometheus/common v0.42.0 github.com/pterm/pterm v0.12.58 @@ -73,7 +74,6 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-github v17.0.0+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 // indirect @@ -113,7 +113,6 @@ require ( github.com/opencontainers/runc v1.1.4 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pkg/profile v1.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect @@ -144,6 +143,7 @@ require ( gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gotest.tools/v3 v3.4.0 // indirect modernc.org/libc v1.22.2 // indirect modernc.org/mathutil v1.5.0 // indirect modernc.org/memory v1.5.0 // indirect diff --git a/go.sum b/go.sum index 26274d6..cf73d72 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,6 @@ github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkU github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 h1:POmUHfxXdeyM8Aomg4tKDcwATCFuW+cYLkj6pwsw9pc= -github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029/go.mod h1:Rpr5n9cGHYdM3S3IK8ROSUUUYjQOu+MSUCZDcJbYWi8= github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -902,7 +900,8 @@ gorm.io/driver/postgres v1.4.8/go.mod h1:O9MruWGNLUBUWVYfWuBClpf3HeGjOoybY0SNmCs gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.24.6 h1:wy98aq9oFEetsc4CAbKD2SoBCdMzsbSIvSUUFJuHi5s= gorm.io/gorm v1.24.6/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= -gotest.tools/v3 v3.2.0 h1:I0DwBVMGAx26dttAj1BtJLAkVGncrkkUXfJLC4Flt/I= +gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= +gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/machine.go b/machine.go index 1d13275..3da4180 100644 --- a/machine.go +++ b/machine.go @@ -4,12 +4,10 @@ import ( "database/sql/driver" "errors" "fmt" - "net" "net/netip" "sort" "strconv" "strings" - "sync" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -162,151 +160,189 @@ func (machine *Machine) isEphemeral() bool { return machine.AuthKey != nil && machine.AuthKey.Ephemeral } +func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { + for _, rule := range filter { + // TODO(kradalby): Cache or pregen this + matcher := MatchFromFilterRule(rule) + + if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { + continue + } + + if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { + return true + } + } + + return false +} + // filterMachinesByACL wrapper function to not have devs pass around locks and maps // related to the application outside of tests. func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines { - return filterMachinesByACL(currentMachine, peers, &h.aclPeerCacheMapRW, h.aclPeerCacheMap) + return filterMachinesByACL(currentMachine, peers, h.aclRules) } // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. func filterMachinesByACL( machine *Machine, machines Machines, - lock *sync.RWMutex, - aclPeerCacheMap map[string][]string, + filter []tailcfg.FilterRule, ) Machines { - log.Trace(). - Caller(). - Str("self", machine.Hostname). - Str("input", machines.String()). - Msg("Finding peers filtered by ACLs") + result := Machines{} - peers := make(map[uint64]Machine) - // Aclfilter peers here. We are itering through machines in all users and search through the computed aclRules - // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. - machineIPs := machine.IPAddresses.ToStringSlice() - - // TODO(kradalby): Remove this lock, I suspect its not a good idea, and might not be necessary, - // we only set this at startup atm (reading ACLs) and it might become a bottleneck. - lock.RLock() - - for _, peer := range machines { + for index, peer := range machines { if peer.ID == machine.ID { continue } - peerIPs := peer.IPAddresses.ToStringSlice() - if dstMap, ok := aclPeerCacheMap["*"]; ok { - // match source and all destination - - for _, dst := range dstMap { - if dst == "*" { - peers[peer.ID] = peer - - continue - } - } - - // match source and all destination - for _, peerIP := range peerIPs { - for _, dst := range dstMap { - _, cdr, _ := net.ParseCIDR(dst) - ip := net.ParseIP(peerIP) - if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { - peers[peer.ID] = peer - - continue - } - } - } - - // match all sources and source - for _, machineIP := range machineIPs { - for _, dst := range dstMap { - _, cdr, _ := net.ParseCIDR(dst) - ip := net.ParseIP(machineIP) - if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { - peers[peer.ID] = peer - - continue - } - } - } - } - - for _, machineIP := range machineIPs { - if dstMap, ok := aclPeerCacheMap[machineIP]; ok { - // match source and all destination - for _, dst := range dstMap { - if dst == "*" { - peers[peer.ID] = peer - - continue - } - } - - // match source and destination - for _, peerIP := range peerIPs { - for _, dst := range dstMap { - _, cdr, _ := net.ParseCIDR(dst) - ip := net.ParseIP(peerIP) - if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { - peers[peer.ID] = peer - - continue - } - } - } - } - } - - for _, peerIP := range peerIPs { - if dstMap, ok := aclPeerCacheMap[peerIP]; ok { - // match source and all destination - for _, dst := range dstMap { - if dst == "*" { - peers[peer.ID] = peer - - continue - } - } - - // match return path - for _, machineIP := range machineIPs { - for _, dst := range dstMap { - _, cdr, _ := net.ParseCIDR(dst) - ip := net.ParseIP(machineIP) - if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { - peers[peer.ID] = peer - - continue - } - } - } - } + if machine.canAccess(filter, &machines[index]) || peer.canAccess(filter, machine) { + result = append(result, peer) } } - lock.RUnlock() - - authorizedPeers := make(Machines, 0, len(peers)) - for _, m := range peers { - authorizedPeers = append(authorizedPeers, m) - } - sort.Slice( - authorizedPeers, - func(i, j int) bool { return authorizedPeers[i].ID < authorizedPeers[j].ID }, - ) - - log.Trace(). - Caller(). - Str("self", machine.Hostname). - Str("peers", authorizedPeers.String()). - Msg("Authorized peers") - - return authorizedPeers + return result } +// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. +// func filterMachinesByACL( +// machine *Machine, +// machines Machines, +// lock *sync.RWMutex, +// aclPeerCacheMap map[string][]string, +// ) Machines { +// log.Trace(). +// Caller(). +// Str("self", machine.Hostname). +// Str("input", machines.String()). +// Msg("Finding peers filtered by ACLs") +// +// peers := make(map[uint64]Machine) +// // Aclfilter peers here. We are itering through machines in all users and search through the computed aclRules +// // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. +// machineIPs := machine.IPAddresses.ToStringSlice() +// +// // TODO(kradalby): Remove this lock, I suspect its not a good idea, and might not be necessary, +// // we only set this at startup atm (reading ACLs) and it might become a bottleneck. +// lock.RLock() +// +// for _, peer := range machines { +// if peer.ID == machine.ID { +// continue +// } +// peerIPs := peer.IPAddresses.ToStringSlice() +// +// if dstMap, ok := aclPeerCacheMap["*"]; ok { +// // match source and all destination +// +// for _, dst := range dstMap { +// if dst == "*" { +// peers[peer.ID] = peer +// +// continue +// } +// } +// +// // match source and all destination +// for _, peerIP := range peerIPs { +// for _, dst := range dstMap { +// _, cdr, _ := net.ParseCIDR(dst) +// ip := net.ParseIP(peerIP) +// if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { +// peers[peer.ID] = peer +// +// continue +// } +// } +// } +// +// // match all sources and source +// for _, machineIP := range machineIPs { +// for _, dst := range dstMap { +// _, cdr, _ := net.ParseCIDR(dst) +// ip := net.ParseIP(machineIP) +// if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { +// peers[peer.ID] = peer +// +// continue +// } +// } +// } +// } +// +// for _, machineIP := range machineIPs { +// if dstMap, ok := aclPeerCacheMap[machineIP]; ok { +// // match source and all destination +// for _, dst := range dstMap { +// if dst == "*" { +// peers[peer.ID] = peer +// +// continue +// } +// } +// +// // match source and destination +// for _, peerIP := range peerIPs { +// for _, dst := range dstMap { +// _, cdr, _ := net.ParseCIDR(dst) +// ip := net.ParseIP(peerIP) +// if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { +// peers[peer.ID] = peer +// +// continue +// } +// } +// } +// } +// } +// +// for _, peerIP := range peerIPs { +// if dstMap, ok := aclPeerCacheMap[peerIP]; ok { +// // match source and all destination +// for _, dst := range dstMap { +// if dst == "*" { +// peers[peer.ID] = peer +// +// continue +// } +// } +// +// // match return path +// for _, machineIP := range machineIPs { +// for _, dst := range dstMap { +// _, cdr, _ := net.ParseCIDR(dst) +// ip := net.ParseIP(machineIP) +// if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { +// peers[peer.ID] = peer +// +// continue +// } +// } +// } +// } +// } +// } +// +// lock.RUnlock() +// +// authorizedPeers := make(Machines, 0, len(peers)) +// for _, m := range peers { +// authorizedPeers = append(authorizedPeers, m) +// } +// sort.Slice( +// authorizedPeers, +// func(i, j int) bool { return authorizedPeers[i].ID < authorizedPeers[j].ID }, +// ) +// +// log.Trace(). +// Caller(). +// Str("self", machine.Hostname). +// Str("peers", authorizedPeers.String()). +// Msg("Authorized peers") +// +// return authorizedPeers +// } + func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). diff --git a/machine_test.go b/machine_test.go index c25f32d..445fe95 100644 --- a/machine_test.go +++ b/machine_test.go @@ -6,7 +6,6 @@ import ( "reflect" "regexp" "strconv" - "sync" "testing" "time" @@ -1041,16 +1040,12 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, } - var lock sync.RWMutex for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - aclRulesMap := generateACLPeerCacheMap(tt.args.rules) - got := filterMachinesByACL( tt.args.machine, tt.args.machines, - &lock, - aclRulesMap, + tt.args.rules, ) if !reflect.DeepEqual(got, tt.want) { t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want)