diff --git a/app_test.go b/app_test.go index ff3755e..5e53f1c 100644 --- a/app_test.go +++ b/app_test.go @@ -38,7 +38,7 @@ func (s *Suite) ResetDB(c *check.C) { c.Fatal(err) } cfg := Config{ - IPPrefix: netaddr.MustParseIPPrefix("127.0.0.1/32"), + IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), } h = Headscale{ diff --git a/cli_test.go b/cli_test.go index 9616b4a..528a115 100644 --- a/cli_test.go +++ b/cli_test.go @@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: n.ID, + IPAddress: "10.0.0.1", } h.db.Save(&m) diff --git a/utils.go b/utils.go index 1da2508..404e382 100644 --- a/utils.go +++ b/utils.go @@ -7,18 +7,11 @@ package headscale import ( "crypto/rand" - "encoding/binary" "encoding/json" - "errors" "fmt" "io" - "net" - "time" - - mathrand "math/rand" "golang.org/x/crypto/nacl/box" - "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/types/wgkey" ) @@ -78,47 +71,73 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err return msg, nil } -func (h *Headscale) getAvailableIP() (*net.IP, error) { - i := 0 +func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { + ipPrefix := h.cfg.IPPrefix + + usedIps, err := h.getUsedIPs() + if err != nil { + return nil, err + } + + // for _, ip := range usedIps { + // nextIP := ip.Next() + + // if !containsIPs(usedIps, nextIP) && ipPrefix.Contains(nextIP) { + // return &nextIP, nil + // } + // } + + // // If there are no IPs in use, we are starting fresh and + // // can issue IPs from the beginning of the prefix. + // ip := ipPrefix.IP() + // return &ip, nil + + // return nil, fmt.Errorf("failed to find any available IP in %s", ipPrefix) + + // Get the first IP in our prefix + ip := ipPrefix.IP() + for { - ip, err := getRandomIP(h.cfg.IPPrefix) + if !ipPrefix.Contains(ip) { + return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix) + } + + if ip.IsZero() && + ip.IsLoopback() { + continue + } + + if !containsIPs(usedIps, ip) { + return &ip, nil + } + + ip = ip.Next() + } +} + +func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { + var addresses []string + h.db.Model(&Machine{}).Pluck("ip_address", &addresses) + + ips := make([]netaddr.IP, len(addresses)) + for index, addr := range addresses { + ip, err := netaddr.ParseIP(addr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse ip from database, %w", err) } - m := Machine{} - if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return ip, nil - } - i++ - if i == 100 { // really random number - break - } - } - return nil, errors.New(fmt.Sprintf("Could not find an available IP address in %s", h.cfg.IPPrefix.String())) -} -func getRandomIP(ipPrefix netaddr.IPPrefix) (*net.IP, error) { - mathrand.Seed(time.Now().Unix()) - ipo, ipnet, err := net.ParseCIDR(ipPrefix.String()) - if err == nil { - ip := ipo.To4() - // fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) - // fmt.Println("Final address is ", ip) - // fmt.Println("Broadcast address is ", ipb) - // fmt.Println("Network address is ", ipn) - r := mathrand.Uint32() - ipRaw := make([]byte, 4) - binary.LittleEndian.PutUint32(ipRaw, r) - // ipRaw[3] = 254 - // fmt.Println("ipRaw is ", ipRaw) - for i, v := range ipRaw { - // fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i]) - ip[i] = ip[i] + (v &^ ipnet.Mask[i]) - // fmt.Println("IP After: ", ip[i]) - } - // fmt.Println("FINAL IP: ", ip.String()) - return &ip, nil + ips[index] = ip } - return nil, err + return ips, nil +} + +func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { + for _, v := range ips { + if v == ip { + return true + } + } + + return false } diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..471b822 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,105 @@ +package headscale + +import ( + "gopkg.in/check.v1" + "inet.af/netaddr" +) + +func (s *Suite) TestGetAvailableIp(c *check.C) { + ip, err := h.getAvailableIP() + + c.Assert(err, check.IsNil) + + expected := netaddr.MustParseIP("10.27.0.0") + + c.Assert(ip.String(), check.Equals, expected.String()) +} + +func (s *Suite) TestGetUsedIps(c *check.C) { + ip, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + n, err := h.CreateNamespace("test_ip") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + IPAddress: ip.String(), + } + h.db.Save(&m) + + ips, err := h.getUsedIPs() + + c.Assert(err, check.IsNil) + + expected := netaddr.MustParseIP("10.27.0.0") + + c.Assert(ips[0], check.Equals, expected) +} + +func (s *Suite) TestGetMultiIp(c *check.C) { + n, err := h.CreateNamespace("test-ip-multi") + c.Assert(err, check.IsNil) + + for i := 1; i <= 350; i++ { + ip, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + IPAddress: ip.String(), + } + h.db.Save(&m) + } + + ips, err := h.getUsedIPs() + + c.Assert(err, check.IsNil) + + c.Assert(len(ips), check.Equals, 350) + + c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.0")) + c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.9")) + c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.44")) + + expectedNextIP := netaddr.MustParseIP("10.27.1.94") + nextIP, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) + + // If we call get Available again, we should receive + // the same IP, as it has not been reserved. + nextIP2, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) +}