diff --git a/acls.go b/acls.go index c197c7c..f405acb 100644 --- a/acls.go +++ b/acls.go @@ -15,7 +15,6 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/tailscale/hujson" - "go4.org/netipx" "gopkg.in/yaml.v3" "tailscale.com/envknob" "tailscale.com/tailcfg" @@ -808,135 +807,3 @@ func (pol *ACLPolicy) getIPsFromIPPrefix( return lo.Uniq(val), nil } - -// This is borrowed from, and updated to use IPSet -// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 -// TODO(kradalby): contribute upstream and make public -var ( - zeroIP4 = netip.AddrFrom4([4]byte{}) - zeroIP6 = netip.AddrFrom16([16]byte{}) -) - -// parseIPSet parses arg as one: -// -// - an IP address (IPv4 or IPv6) -// - the string "*" to match everything (both IPv4 & IPv6) -// - a CIDR (e.g. "192.168.0.0/16") -// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") -// -// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP -// address (without a slash) treated as a CIDR of *bits length. -// nolint -func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { - var ipSet netipx.IPSetBuilder - if arg == "*" { - ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) - ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) - - return ipSet.IPSet() - } - if strings.Contains(arg, "/") { - pfx, err := netip.ParsePrefix(arg) - if err != nil { - return nil, err - } - if pfx != pfx.Masked() { - return nil, fmt.Errorf("%v contains non-network bits set", pfx) - } - - ipSet.AddPrefix(pfx) - - return ipSet.IPSet() - } - if strings.Count(arg, "-") == 1 { - ip1s, ip2s, _ := strings.Cut(arg, "-") - - ip1, err := netip.ParseAddr(ip1s) - if err != nil { - return nil, err - } - - ip2, err := netip.ParseAddr(ip2s) - if err != nil { - return nil, err - } - - r := netipx.IPRangeFrom(ip1, ip2) - if !r.IsValid() { - return nil, fmt.Errorf("invalid IP range %q", arg) - } - - for _, prefix := range r.Prefixes() { - ipSet.AddPrefix(prefix) - } - - return ipSet.IPSet() - } - ip, err := netip.ParseAddr(arg) - if err != nil { - return nil, fmt.Errorf("invalid IP address %q", arg) - } - bits8 := uint8(ip.BitLen()) - if bits != nil { - if *bits < 0 || *bits > int(bits8) { - return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) - } - bits8 = uint8(*bits) - } - - ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) - - return ipSet.IPSet() -} - -type Match struct { - Srcs *netipx.IPSet - Dests *netipx.IPSet -} - -func MatchFromFilterRule(rule tailcfg.FilterRule) Match { - srcs := new(netipx.IPSetBuilder) - dests := new(netipx.IPSetBuilder) - - for _, srcIP := range rule.SrcIPs { - set, _ := parseIPSet(srcIP, nil) - - srcs.AddSet(set) - } - - for _, dest := range rule.DstPorts { - set, _ := parseIPSet(dest.IP, nil) - - dests.AddSet(set) - } - - srcsSet, _ := srcs.IPSet() - destsSet, _ := dests.IPSet() - - match := Match{ - Srcs: srcsSet, - Dests: destsSet, - } - - return match -} - -func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Srcs.Contains(ip) { - return true - } - } - - return false -} - -func (m *Match) DestsContainsIP(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Dests.Contains(ip) { - return true - } - } - - return false -} diff --git a/matcher.go b/matcher.go new file mode 100644 index 0000000..1a186c4 --- /dev/null +++ b/matcher.go @@ -0,0 +1,142 @@ +package headscale + +import ( + "fmt" + "net/netip" + "strings" + + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +// This is borrowed from, and updated to use IPSet +// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 +// TODO(kradalby): contribute upstream and make public. +var ( + zeroIP4 = netip.AddrFrom4([4]byte{}) + zeroIP6 = netip.AddrFrom16([16]byte{}) +) + +// parseIPSet parses arg as one: +// +// - an IP address (IPv4 or IPv6) +// - the string "*" to match everything (both IPv4 & IPv6) +// - a CIDR (e.g. "192.168.0.0/16") +// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") +// +// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP +// address (without a slash) treated as a CIDR of *bits length. +// nolint +func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { + var ipSet netipx.IPSetBuilder + if arg == "*" { + ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) + ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) + + return ipSet.IPSet() + } + if strings.Contains(arg, "/") { + pfx, err := netip.ParsePrefix(arg) + if err != nil { + return nil, err + } + if pfx != pfx.Masked() { + return nil, fmt.Errorf("%v contains non-network bits set", pfx) + } + + ipSet.AddPrefix(pfx) + + return ipSet.IPSet() + } + if strings.Count(arg, "-") == 1 { + ip1s, ip2s, _ := strings.Cut(arg, "-") + + ip1, err := netip.ParseAddr(ip1s) + if err != nil { + return nil, err + } + + ip2, err := netip.ParseAddr(ip2s) + if err != nil { + return nil, err + } + + r := netipx.IPRangeFrom(ip1, ip2) + if !r.IsValid() { + return nil, fmt.Errorf("invalid IP range %q", arg) + } + + for _, prefix := range r.Prefixes() { + ipSet.AddPrefix(prefix) + } + + return ipSet.IPSet() + } + ip, err := netip.ParseAddr(arg) + if err != nil { + return nil, fmt.Errorf("invalid IP address %q", arg) + } + bits8 := uint8(ip.BitLen()) + if bits != nil { + if *bits < 0 || *bits > int(bits8) { + return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) + } + bits8 = uint8(*bits) + } + + ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) + + return ipSet.IPSet() +} + +type Match struct { + Srcs *netipx.IPSet + Dests *netipx.IPSet +} + +func MatchFromFilterRule(rule tailcfg.FilterRule) Match { + srcs := new(netipx.IPSetBuilder) + dests := new(netipx.IPSetBuilder) + + for _, srcIP := range rule.SrcIPs { + set, _ := parseIPSet(srcIP, nil) + + srcs.AddSet(set) + } + + for _, dest := range rule.DstPorts { + set, _ := parseIPSet(dest.IP, nil) + + dests.AddSet(set) + } + + srcsSet, _ := srcs.IPSet() + destsSet, _ := dests.IPSet() + + match := Match{ + Srcs: srcsSet, + Dests: destsSet, + } + + return match +} + +func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Srcs.Contains(ip) { + return true + } + } + + return false +} + +func (m *Match) DestsContainsIP(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Dests.Contains(ip) { + return true + } + } + + return false +} diff --git a/matcher_test.go b/matcher_test.go new file mode 100644 index 0000000..03b585c --- /dev/null +++ b/matcher_test.go @@ -0,0 +1,119 @@ +package headscale + +import ( + "net/netip" + "reflect" + "testing" + + "go4.org/netipx" +) + +func Test_parseIPSet(t *testing.T) { + set := func(ips []string, prefixes []string) *netipx.IPSet { + var builder netipx.IPSetBuilder + + for _, ip := range ips { + builder.Add(netip.MustParseAddr(ip)) + } + + for _, pre := range prefixes { + builder.AddPrefix(netip.MustParsePrefix(pre)) + } + + s, _ := builder.IPSet() + + return s + } + + type args struct { + arg string + bits *int + } + tests := []struct { + name string + args args + want *netipx.IPSet + wantErr bool + }{ + { + name: "simple ip4", + args: args{ + arg: "10.0.0.1", + bits: nil, + }, + want: set([]string{ + "10.0.0.1", + }, []string{}), + wantErr: false, + }, + { + name: "simple ip6", + args: args{ + arg: "2001:db8:abcd:1234::2", + bits: nil, + }, + want: set([]string{ + "2001:db8:abcd:1234::2", + }, []string{}), + wantErr: false, + }, + { + name: "wildcard", + args: args{ + arg: "*", + bits: nil, + }, + want: set([]string{}, []string{ + "0.0.0.0/0", + "::/0", + }), + wantErr: false, + }, + { + name: "prefix4", + args: args{ + arg: "192.168.0.0/16", + bits: nil, + }, + want: set([]string{}, []string{ + "192.168.0.0/16", + }), + wantErr: false, + }, + { + name: "prefix6", + args: args{ + arg: "2001:db8:abcd:1234::/64", + bits: nil, + }, + want: set([]string{}, []string{ + "2001:db8:abcd:1234::/64", + }), + wantErr: false, + }, + { + name: "range4", + args: args{ + arg: "192.168.0.0-192.168.255.255", + bits: nil, + }, + want: set([]string{}, []string{ + "192.168.0.0/16", + }), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseIPSet(tt.args.arg, tt.args.bits) + if (err != nil) != tt.wantErr { + t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) + + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseIPSet() = %v, want %v", got, tt.want) + } + }) + } +}