diff --git a/acls.go b/acls.go index cbe2f71..c197c7c 100644 --- a/acls.go +++ b/acls.go @@ -809,8 +809,9 @@ func (pol *ACLPolicy) getIPsFromIPPrefix( return lo.Uniq(val), nil } -// This is borrowed from +// This is borrowed from, and updated to use IPSet // https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 +// TODO(kradalby): contribute upstream and make public var ( zeroIP4 = netip.AddrFrom4([4]byte{}) zeroIP6 = netip.AddrFrom16([16]byte{}) @@ -825,19 +826,14 @@ var ( // // bits, if non-nil, is the legacy SrcBits CIDR length to make a IP // address (without a slash) treated as a CIDR of *bits length. -// -// TODO(bradfitz): make this return an IPSet and plumb that all -// around, and ultimately use a new version of IPSet.ContainsFunc like -// Contains16Func that works in [16]byte address, so we we can match -// at runtime without allocating? // nolint -func parseIPSet(arg string, bits *int) ([]netip.Prefix, error) { +func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { + var ipSet netipx.IPSetBuilder if arg == "*" { - // User explicitly requested wildcard. - return []netip.Prefix{ - netip.PrefixFrom(zeroIP4, 0), - netip.PrefixFrom(zeroIP6, 0), - }, nil + ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) + ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) + + return ipSet.IPSet() } if strings.Contains(arg, "/") { pfx, err := netip.ParsePrefix(arg) @@ -848,7 +844,9 @@ func parseIPSet(arg string, bits *int) ([]netip.Prefix, error) { return nil, fmt.Errorf("%v contains non-network bits set", pfx) } - return []netip.Prefix{pfx}, nil + ipSet.AddPrefix(pfx) + + return ipSet.IPSet() } if strings.Count(arg, "-") == 1 { ip1s, ip2s, _ := strings.Cut(arg, "-") @@ -868,7 +866,11 @@ func parseIPSet(arg string, bits *int) ([]netip.Prefix, error) { return nil, fmt.Errorf("invalid IP range %q", arg) } - return r.Prefixes(), nil + for _, prefix := range r.Prefixes() { + ipSet.AddPrefix(prefix) + } + + return ipSet.IPSet() } ip, err := netip.ParseAddr(arg) if err != nil { @@ -882,41 +884,47 @@ func parseIPSet(arg string, bits *int) ([]netip.Prefix, error) { bits8 = uint8(*bits) } - return []netip.Prefix{netip.PrefixFrom(ip, int(bits8))}, nil + ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) + + return ipSet.IPSet() } type Match struct { - Srcs []netip.Prefix - Dests []netip.Prefix + Srcs *netipx.IPSet + Dests *netipx.IPSet } func MatchFromFilterRule(rule tailcfg.FilterRule) Match { - match := Match{ - Srcs: []netip.Prefix{}, - Dests: []netip.Prefix{}, - } + srcs := new(netipx.IPSetBuilder) + dests := new(netipx.IPSetBuilder) for _, srcIP := range rule.SrcIPs { - prefix, _ := parseIPSet(srcIP, nil) + set, _ := parseIPSet(srcIP, nil) - match.Srcs = append(match.Srcs, prefix...) + srcs.AddSet(set) } for _, dest := range rule.DstPorts { - prefix, _ := parseIPSet(dest.IP, nil) + set, _ := parseIPSet(dest.IP, nil) - match.Dests = append(match.Dests, prefix...) + dests.AddSet(set) + } + + srcsSet, _ := srcs.IPSet() + destsSet, _ := dests.IPSet() + + match := Match{ + Srcs: srcsSet, + Dests: destsSet, } return match } func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { - for _, prefix := range m.Srcs { - for _, ip := range ips { - if prefix.Contains(ip) { - return true - } + for _, ip := range ips { + if m.Srcs.Contains(ip) { + return true } } @@ -924,11 +932,9 @@ func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { } func (m *Match) DestsContainsIP(ips []netip.Addr) bool { - for _, prefix := range m.Dests { - for _, ip := range ips { - if prefix.Contains(ip) { - return true - } + for _, ip := range ips { + if m.Dests.Contains(ip) { + return true } }