From cdf48b12166b5ae9e37c2d4f0c0dc3704eccb260 Mon Sep 17 00:00:00 2001 From: Juan Font Alonso Date: Fri, 2 Sep 2022 00:05:18 +0200 Subject: [PATCH] Migrate utils to net/netip --- utils.go | 35 ++++++++++++++++++----------------- utils_test.go | 23 +++++++++++++---------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/utils.go b/utils.go index e7fb13a..666683c 100644 --- a/utils.go +++ b/utils.go @@ -13,6 +13,7 @@ import ( "fmt" "io/fs" "net" + "net/netip" "os" "path/filepath" "reflect" @@ -21,7 +22,7 @@ import ( "github.com/rs/zerolog/log" "github.com/spf13/viper" - "inet.af/netaddr" + "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -140,7 +141,7 @@ func (h *Headscale) getAvailableIPs() (MachineAddresses, error) { var err error ipPrefixes := h.cfg.IPPrefixes for _, ipPrefix := range ipPrefixes { - var ip *netaddr.IP + var ip *netip.Addr ip, err = h.getAvailableIP(ipPrefix) if err != nil { return ips, err @@ -151,16 +152,16 @@ func (h *Headscale) getAvailableIPs() (MachineAddresses, error) { return ips, err } -func GetIPPrefixEndpoints(na netaddr.IPPrefix) (netaddr.IP, netaddr.IP) { - var network, broadcast netaddr.IP - ipRange := na.Range() +func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { + var network, broadcast netip.Addr + ipRange := netipx.RangeOfPrefix(na) network = ipRange.From() broadcast = ipRange.To() return network, broadcast } -func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) { +func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { usedIps, err := h.getUsedIPs() if err != nil { return nil, err @@ -181,7 +182,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro fallthrough case usedIps.Contains(ip): fallthrough - case ip.IsZero() || ip.IsLoopback(): + case ip == netip.Addr{} || ip.IsLoopback(): ip = ip.Next() continue @@ -192,19 +193,19 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro } } -func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) { +func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - var ips netaddr.IPSetBuilder + var ips netipx.IPSetBuilder for _, slice := range addressesSlices { var machineAddresses MachineAddresses err := machineAddresses.Scan(slice) if err != nil { - return &netaddr.IPSet{}, fmt.Errorf( + return &netipx.IPSet{}, fmt.Errorf( "failed to read ip from database: %w", err, ) @@ -217,7 +218,7 @@ func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) { ipSet, err := ips.IPSet() if err != nil { - return &netaddr.IPSet{}, fmt.Errorf( + return &netipx.IPSet{}, fmt.Errorf( "failed to build IP Set: %w", err, ) @@ -250,7 +251,7 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { return d.DialContext(ctx, "unix", addr) } -func ipPrefixToString(prefixes []netaddr.IPPrefix) []string { +func ipPrefixToString(prefixes []netip.Prefix) []string { result := make([]string, len(prefixes)) for index, prefix := range prefixes { @@ -260,13 +261,13 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string { return result } -func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { - result := make([]netaddr.IPPrefix, len(prefixes)) +func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { + result := make([]netip.Prefix, len(prefixes)) for index, prefixStr := range prefixes { - prefix, err := netaddr.ParseIPPrefix(prefixStr) + prefix, err := netip.ParsePrefix(prefixStr) if err != nil { - return []netaddr.IPPrefix{}, err + return []netip.Prefix{}, err } result[index] = prefix @@ -275,7 +276,7 @@ func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { return result, nil } -func contains[T string | netaddr.IPPrefix](ts []T, t T) bool { +func contains[T string | netip.Prefix](ts []T, t T) bool { for _, v := range ts { if reflect.DeepEqual(v, t) { return true diff --git a/utils_test.go b/utils_test.go index 07fa62d..13f9f0b 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,6 +1,9 @@ package headscale import ( + "net/netip" + + "go4.org/netipx" "gopkg.in/check.v1" "inet.af/netaddr" ) @@ -10,7 +13,7 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { c.Assert(err, check.IsNil) - expected := netaddr.MustParseIP("10.27.0.1") + expected := netip.MustParseAddr("10.27.0.1") c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) @@ -46,8 +49,8 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(err, check.IsNil) - expected := netaddr.MustParseIP("10.27.0.1") - expectedIPSetBuilder := netaddr.IPSetBuilder{} + expected := netip.MustParseAddr("10.27.0.1") + expectedIPSetBuilder := netipx.IPSetBuilder{} expectedIPSetBuilder.Add(expected) expectedIPSet, _ := expectedIPSetBuilder.IPSet() @@ -96,11 +99,11 @@ func (s *Suite) TestGetMultiIp(c *check.C) { usedIps, err := app.getUsedIPs() c.Assert(err, check.IsNil) - expected0 := netaddr.MustParseIP("10.27.0.1") - expected9 := netaddr.MustParseIP("10.27.0.10") - expected300 := netaddr.MustParseIP("10.27.0.45") + expected0 := netip.MustParseAddr("10.27.0.1") + expected9 := netip.MustParseAddr("10.27.0.10") + expected300 := netip.MustParseAddr("10.27.0.45") - notExpectedIPSetBuilder := netaddr.IPSetBuilder{} + notExpectedIPSetBuilder := netipx.IPSetBuilder{} notExpectedIPSetBuilder.Add(expected0) notExpectedIPSetBuilder.Add(expected9) notExpectedIPSetBuilder.Add(expected300) @@ -121,7 +124,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert( machine1.IPAddresses[0], check.Equals, - netaddr.MustParseIP("10.27.0.1"), + netip.MustParseAddr("10.27.0.1"), ) machine50, err := app.GetMachineByID(50) @@ -130,10 +133,10 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert( machine50.IPAddresses[0], check.Equals, - netaddr.MustParseIP("10.27.0.50"), + netip.MustParseAddr("10.27.0.50"), ) - expectedNextIP := netaddr.MustParseIP("10.27.1.95") + expectedNextIP := netip.MustParseAddr("10.27.1.95") nextIP, err := app.getAvailableIPs() c.Assert(err, check.IsNil)