diff --git a/acls.go b/acls.go index 4017e28..5e191e6 100644 --- a/acls.go +++ b/acls.go @@ -185,7 +185,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) { return nil, errInvalidNamespace } for _, node := range nodes { - ips = append(ips, node.IPAddress) + ips = append(ips, node.IPAddresses.ToStringSlice()...) } } @@ -219,7 +219,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) { // FIXME: Check TagOwners allows this for _, t := range hostinfo.RequestTags { if alias[4:] == t { - ips = append(ips, machine.IPAddress) + ips = append(ips, machine.IPAddresses.ToStringSlice()...) break } @@ -238,7 +238,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) { } ips := []string{} for _, n := range nodes { - ips = append(ips, n.IPAddress) + ips = append(ips, n.IPAddresses.ToStringSlice()...) } return ips, nil diff --git a/acls_test.go b/acls_test.go index 629ce1d..c35f4f8 100644 --- a/acls_test.go +++ b/acls_test.go @@ -61,9 +61,9 @@ func (s *Suite) TestPortRange(c *check.C) { c.Assert(rules, check.NotNil) c.Assert(rules, check.HasLen, 1) - c.Assert((rules)[0].DstPorts, check.HasLen, 1) - c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(5400)) - c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500)) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(5400)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500)) } func (s *Suite) TestPortWildcard(c *check.C) { @@ -75,11 +75,11 @@ func (s *Suite) TestPortWildcard(c *check.C) { c.Assert(rules, check.NotNil) c.Assert(rules, check.HasLen, 1) - c.Assert((rules)[0].DstPorts, check.HasLen, 1) - c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert((rules)[0].SrcIPs, check.HasLen, 1) - c.Assert((rules)[0].SrcIPs[0], check.Equals, "*") + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "*") } func (s *Suite) TestPortNamespace(c *check.C) { @@ -91,7 +91,7 @@ func (s *Suite) TestPortNamespace(c *check.C) { _, err = app.GetMachine("testnamespace", "testmachine") c.Assert(err, check.NotNil) - ip, _ := app.getAvailableIP() + ips, _ := app.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "foo", @@ -101,7 +101,7 @@ func (s *Suite) TestPortNamespace(c *check.C) { NamespaceID: namespace.ID, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: ip.String(), + IPAddresses: ips, AuthKeyID: uint(pak.ID), } app.db.Save(&machine) @@ -116,12 +116,13 @@ func (s *Suite) TestPortNamespace(c *check.C) { c.Assert(rules, check.NotNil) c.Assert(rules, check.HasLen, 1) - c.Assert((rules)[0].DstPorts, check.HasLen, 1) - c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert((rules)[0].SrcIPs, check.HasLen, 1) - c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String()) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) } func (s *Suite) TestPortGroup(c *check.C) { @@ -133,7 +134,7 @@ func (s *Suite) TestPortGroup(c *check.C) { _, err = app.GetMachine("testnamespace", "testmachine") c.Assert(err, check.NotNil) - ip, _ := app.getAvailableIP() + ips, _ := app.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "foo", @@ -143,7 +144,7 @@ func (s *Suite) TestPortGroup(c *check.C) { NamespaceID: namespace.ID, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: ip.String(), + IPAddresses: ips, AuthKeyID: uint(pak.ID), } app.db.Save(&machine) @@ -156,10 +157,11 @@ func (s *Suite) TestPortGroup(c *check.C) { c.Assert(rules, check.NotNil) c.Assert(rules, check.HasLen, 1) - c.Assert((rules)[0].DstPorts, check.HasLen, 1) - c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert((rules)[0].SrcIPs, check.HasLen, 1) - c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String()) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) } diff --git a/api.go b/api.go index d3bd572..020ded0 100644 --- a/api.go +++ b/api.go @@ -497,6 +497,7 @@ func (h *Headscale) handleMachineRegistrationNew( ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } +// TODO: check if any locks are needed around IP allocation. func (h *Headscale) handleAuthKey( ctx *gin.Context, machineKey key.MachinePublic, @@ -554,14 +555,14 @@ func (h *Headscale) handleAuthKey( log.Debug(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Msg("Authentication key was valid, proceeding to acquire an IP address") - ip, err := h.getAvailableIP() + Msg("Authentication key was valid, proceeding to acquire IP addresses") + ips, err := h.getAvailableIPs() if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Msg("Failed to find an available IP") + Msg("Failed to find an available IP address") machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() @@ -570,12 +571,12 @@ func (h *Headscale) handleAuthKey( log.Info(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Str("ip", ip.String()). - Msgf("Assigning %s to %s", ip, machine.Name) + Str("ips", strings.Join(ips.ToStringSlice(), ",")). + Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name) machine.Expiry = ®isterRequest.Expiry machine.AuthKeyID = uint(pak.ID) - machine.IPAddress = ip.String() + machine.IPAddresses = ips machine.NamespaceID = pak.NamespaceID machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) @@ -610,6 +611,6 @@ func (h *Headscale) handleAuthKey( log.Info(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Str("ip", machine.IPAddress). + Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Msg("Successfully authenticated via AuthKey") } diff --git a/app.go b/app.go index b9d570c..b7dc8c6 100644 --- a/app.go +++ b/app.go @@ -68,7 +68,7 @@ type Config struct { ServerURL string Addr string EphemeralNodeInactivityTimeout time.Duration - IPPrefix netaddr.IPPrefix + IPPrefixes []netaddr.IPPrefix PrivateKeyPath string BaseDomain string @@ -197,9 +197,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) { } if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS - magicDNSDomains := generateMagicDNSRootDomains( - app.cfg.IPPrefix, - ) + magicDNSDomains := generateMagicDNSRootDomains(app.cfg.IPPrefixes) // we might have routes already from Split DNS if app.cfg.DNSConfig.Routes == nil { app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) diff --git a/app_test.go b/app_test.go index bff1393..02fdce8 100644 --- a/app_test.go +++ b/app_test.go @@ -41,7 +41,9 @@ func (s *Suite) ResetDB(c *check.C) { c.Fatal(err) } cfg := Config{ - IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), + IPPrefixes: []netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("10.27.0.0/23"), + }, } app = Headscale{ diff --git a/cli_test.go b/cli_test.go index ef7e299..71f2bea 100644 --- a/cli_test.go +++ b/cli_test.go @@ -4,6 +4,7 @@ import ( "time" "gopkg.in/check.v1" + "inet.af/netaddr" ) func (s *Suite) TestRegisterMachine(c *check.C) { @@ -19,16 +20,17 @@ func (s *Suite) TestRegisterMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - IPAddress: "10.0.0.1", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("10.0.0.1")}, Expiry: &now, } - app.db.Save(&machine) + err = app.db.Save(&machine).Error + c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.GetMachine(namespace.Name, machine.Name) c.Assert(err, check.IsNil) machineAfterRegistering, err := app.RegisterMachine( - "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", + machine.MachineKey, namespace.Name, ) c.Assert(err, check.IsNil) diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 26ead6d..d6b86ee 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "strconv" + "strings" "time" survey "github.com/AlecAivazis/survey/v2" @@ -459,7 +460,7 @@ func nodesToPtables( "Name", "NodeKey", "Namespace", - "IP address", + "IP addresses", "Ephemeral", "Last seen", "Online", @@ -523,7 +524,7 @@ func nodesToPtables( machine.Name, nodeKey.ShortString(), namespace, - machine.IpAddress, + strings.Join(machine.IpAddresses, ", "), strconv.FormatBool(ephemeral), lastSeenTime, online, diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index a485664..2c8f2a6 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -41,7 +41,7 @@ func LoadConfig(path string) error { viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") - viper.SetDefault("ip_prefix", "100.64.0.0/10") + viper.SetDefault("ip_prefixes", []string{"100.64.0.0/10"}) viper.SetDefault("log_level", "info") @@ -221,10 +221,20 @@ func getHeadscaleConfig() headscale.Config { dnsConfig, baseDomain := GetDNSConfig() derpConfig := GetDERPConfig() + configuredPrefixes := viper.GetStringSlice("ip_prefixes") + prefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)) + for i, prefixInConfig := range configuredPrefixes { + prefix, err := netaddr.ParseIPPrefix(prefixInConfig) + if err != nil { + panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err)) + } + prefixes = append(prefixes, prefix) + } + return headscale.Config{ ServerURL: viper.GetString("server_url"), Addr: viper.GetString("listen_addr"), - IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")), + IPPrefixes: prefixes, PrivateKeyPath: absPath(viper.GetString("private_key_path")), BaseDomain: baseDomain, diff --git a/dns.go b/dns.go index af6f989..5a4f682 100644 --- a/dns.go +++ b/dns.go @@ -34,14 +34,25 @@ const ( // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // This allows us to then calculate the subnets included in the subsequent class block and generate the entries. -func generateMagicDNSRootDomains( - ipPrefix netaddr.IPPrefix, -) []dnsname.FQDN { - // TODO(juanfont): we are not handing out IPv6 addresses yet - // and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network) - ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.") - fqdns := []dnsname.FQDN{ipv6base} +func generateMagicDNSRootDomains(ipPrefixes []netaddr.IPPrefix) []dnsname.FQDN { + fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes)) + for _, ipPrefix := range ipPrefixes { + var generateDnsRoot func(netaddr.IPPrefix) []dnsname.FQDN + switch ipPrefix.IP().BitLen() { + case 32: + generateDnsRoot = generateIPv4DNSRootDomain + default: + panic(fmt.Sprintf("unsupported IP version with address length %d", ipPrefix.IP().BitLen())) + } + + fqdns = append(fqdns, generateDnsRoot(ipPrefix)...) + } + + return fqdns +} + +func generateIPv4DNSRootDomain(ipPrefix netaddr.IPPrefix) (fqdns []dnsname.FQDN) { // Conversion to the std lib net.IPnet, a bit easier to operate netRange := ipPrefix.IPNet() maskBits, _ := netRange.Mask.Size() @@ -73,7 +84,7 @@ func generateMagicDNSRootDomains( fqdns = append(fqdns, fqdn) } - return fqdns + return } func getMapResponseDNSConfig( diff --git a/dns_test.go b/dns_test.go index 92f7476..8c0da63 100644 --- a/dns_test.go +++ b/dns_test.go @@ -124,7 +124,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.1", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } app.db.Save(machineInShared1) @@ -142,7 +142,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Namespace: *namespaceShared2, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.2", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } app.db.Save(machineInShared2) @@ -160,7 +160,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Namespace: *namespaceShared3, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.3", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } app.db.Save(machineInShared3) @@ -178,7 +178,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.4", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), } app.db.Save(machine2InShared1) @@ -273,7 +273,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.1", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } app.db.Save(machineInShared1) @@ -291,7 +291,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Namespace: *namespaceShared2, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.2", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } app.db.Save(machineInShared2) @@ -309,7 +309,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Namespace: *namespaceShared3, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.3", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } app.db.Save(machineInShared3) @@ -327,7 +327,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.4", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), } app.db.Save(machine2InShared1) diff --git a/integration_test.go b/integration_test.go index fc71ed1..ade85bf 100644 --- a/integration_test.go +++ b/integration_test.go @@ -372,70 +372,74 @@ func (s *IntegrationTestSuite) TestListNodes() { func (s *IntegrationTestSuite) TestGetIpAddresses() { for _, scales := range s.namespaces { - ipPrefix := netaddr.MustParseIPPrefix("100.64.0.0/10") ips, err := getIPs(scales.tailscales) assert.Nil(s.T(), err) - for hostname := range scales.tailscales { - s.T().Run(hostname, func(t *testing.T) { - ip, ok := ips[hostname] + for hostname, _ := range scales.tailscales { + ips := ips[hostname] + for _, ip := range ips { + s.T().Run(hostname, func(t *testing.T) { + assert.NotNil(t, ip) - assert.True(t, ok) - assert.NotNil(t, ip) + fmt.Printf("IP for %s: %s\n", hostname, ip) - fmt.Printf("IP for %s: %s\n", hostname, ip) - - // c.Assert(ip.Valid(), check.IsTrue) - assert.True(t, ip.Is4()) - assert.True(t, ipPrefix.Contains(ip)) - }) + // c.Assert(ip.Valid(), check.IsTrue) + assert.True(t, ip.Is4() || ip.Is6()) + switch { + case ip.Is4(): + assert.True(t, IpPrefix4.Contains(ip)) + case ip.Is6(): + assert.True(t, IpPrefix6.Contains(ip)) + } + }) + } } } } // TODO(kradalby): fix this test -// We need some way to impot ipnstate.Status from multiple go packages. +// We need some way to import ipnstate.Status from multiple go packages. // Currently it will only work with 1.18.x since that is the last // version we have in go.mod // func (s *IntegrationTestSuite) TestStatus() { -// for _, scales := range s.namespaces { -// ips, err := getIPs(scales.tailscales) -// assert.Nil(s.T(), err) +// for _, scales := range s.namespaces { +// ips, err := getIPs(scales.tailscales) +// assert.Nil(s.T(), err) // -// for hostname, tailscale := range scales.tailscales { -// s.T().Run(hostname, func(t *testing.T) { -// command := []string{"tailscale", "status", "--json"} +// for hostname, tailscale := range scales.tailscales { +// s.T().Run(hostname, func(t *testing.T) { +// command := []string{"tailscale", "status", "--json"} // -// fmt.Printf("Getting status for %s\n", hostname) -// result, err := ExecuteCommand( -// &tailscale, -// command, -// []string{}, -// ) -// assert.Nil(t, err) +// fmt.Printf("Getting status for %s\n", hostname) +// result, err := ExecuteCommand( +// &tailscale, +// command, +// []string{}, +// ) +// assert.Nil(t, err) // -// var status ipnstate.Status -// err = json.Unmarshal([]byte(result), &status) -// assert.Nil(s.T(), err) +// var status ipnstate.Status +// err = json.Unmarshal([]byte(result), &status) +// assert.Nil(s.T(), err) // -// // TODO(kradalby): Replace this check with peer length of SAME namespace -// // Check if we have as many nodes in status -// // as we have IPs/tailscales -// // lines := strings.Split(result, "\n") -// // assert.Equal(t, len(ips), len(lines)-1) -// // assert.Equal(t, len(scales.tailscales), len(lines)-1) +// // TODO(kradalby): Replace this check with peer length of SAME namespace +// // Check if we have as many nodes in status +// // as we have IPs/tailscales +// // lines := strings.Split(result, "\n") +// // assert.Equal(t, len(ips), len(lines)-1) +// // assert.Equal(t, len(scales.tailscales), len(lines)-1) // -// peerIps := getIPsfromIPNstate(status) +// peerIps := getIPsfromIPNstate(status) // -// // Check that all hosts is present in all hosts status -// for ipHostname, ip := range ips { -// if hostname != ipHostname { -// assert.Contains(t, peerIps, ip) -// } -// } -// }) -// } -// } +// // Check that all hosts is present in all hosts status +// for ipHostname, ip := range ips { +// if hostname != ipHostname { +// assert.Contains(t, peerIps, ip) +// } +// } +// }) +// } +// } // } func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP { @@ -448,16 +452,19 @@ func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP { return ips } -func (s *IntegrationTestSuite) TestPingAllPeers() { +func (s *IntegrationTestSuite) TestPingAllPeersByAddress() { for _, scales := range s.namespaces { ips, err := getIPs(scales.tailscales) assert.Nil(s.T(), err) for hostname, tailscale := range scales.tailscales { - for peername, ip := range ips { - s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { + for peername, peerIPs := range ips { + for i, ip := range peerIPs { // We currently cant ping ourselves, so skip that. - if peername != hostname { + if peername == hostname { + continue + } + s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) { // We are only interested in "direct ping" which means what we // might need a couple of more attempts before reaching the node. command := []string{ @@ -469,9 +476,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() { } fmt.Printf( - "Pinging from %s (%s) to %s (%s)\n", + "Pinging from %s to %s (%s)\n", hostname, - ips[hostname], peername, ip, ) @@ -483,8 +489,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() { assert.Nil(t, err) fmt.Printf("Result for %s: %s\n", hostname, result) assert.Contains(t, result, "pong") - } - }) + }) + } } } } @@ -553,17 +559,17 @@ func (s *IntegrationTestSuite) TestSharedNodes() { // TODO(juanfont): We have to find out why do we need to wait time.Sleep(100 * time.Second) // Wait for the nodes to receive updates - mainIps, err := getIPs(main.tailscales) - assert.Nil(s.T(), err) - sharedIps, err := getIPs(shared.tailscales) assert.Nil(s.T(), err) for hostname, tailscale := range main.tailscales { - for peername, ip := range sharedIps { - s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { + for peername, peerIPs := range sharedIps { + for i, ip := range peerIPs { // We currently cant ping ourselves, so skip that. - if peername != hostname { + if peername == hostname { + continue + } + s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) { // We are only interested in "direct ping" which means what we // might need a couple of more attempts before reaching the node. command := []string{ @@ -575,9 +581,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() { } fmt.Printf( - "Pinging from %s (%s) to %s (%s)\n", + "Pinging from %s to %s (%s)\n", hostname, - mainIps[hostname], peername, ip, ) @@ -589,8 +594,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() { assert.Nil(t, err) fmt.Printf("Result for %s: %s\n", hostname, result) assert.Contains(t, result, "pong") - } - }) + }) + } } } } @@ -607,7 +612,7 @@ func (s *IntegrationTestSuite) TestTailDrop() { _, err := ExecuteCommand( &tailscale, command, - []string{}, + []string{"GOMAXPROCS=32"}, ) assert.Nil(s.T(), err) for peername, ip := range ips { @@ -653,7 +658,7 @@ func (s *IntegrationTestSuite) TestTailDrop() { _, err = ExecuteCommand( &tailscale, command, - []string{"ALL_PROXY=socks5://localhost:1055"}, + []string{"ALL_PROXY=socks5://localhost:1055", "GOMAXPROCS=32"}, ) if err == nil { break @@ -684,78 +689,125 @@ func (s *IntegrationTestSuite) TestTailDrop() { ) assert.Nil(s.T(), err) for peername, ip := range ips { + if peername == hostname { + continue + } s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { - if peername != hostname { - command := []string{ - "ls", - fmt.Sprintf("/tmp/file_from_%s", peername), - } - fmt.Printf( - "Checking file in %s (%s) from %s (%s)\n", - hostname, - ips[hostname], - peername, - ip, - ) - result, err := ExecuteCommand( - &tailscale, - command, - []string{}, - ) - assert.Nil(t, err) - fmt.Printf("Result for %s: %s\n", peername, result) - assert.Equal( - t, - result, - fmt.Sprintf("/tmp/file_from_%s\n", peername), - ) + command := []string{ + "ls", + fmt.Sprintf("/tmp/file_from_%s", peername), } + fmt.Printf( + "Checking file in %s (%s) from %s (%s)\n", + hostname, + ips[hostname], + peername, + ip, + ) + result, err := ExecuteCommand( + &tailscale, + command, + []string{}, + ) + assert.Nil(t, err) + fmt.Printf("Result for %s: %s\n", peername, result) + assert.Equal( + t, + fmt.Sprintf("/tmp/file_from_%s\n", peername), + result, + ) }) } } } } -func (s *IntegrationTestSuite) TestMagicDNS() { +func (s *IntegrationTestSuite) TestPingAllPeersByHostname() { for namespace, scales := range s.namespaces { ips, err := getIPs(scales.tailscales) assert.Nil(s.T(), err) for hostname, tailscale := range scales.tailscales { - for peername, ip := range ips { + for peername, _ := range ips { + if peername == hostname { + continue + } s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { - if peername != hostname { - command := []string{ - "tailscale", "ping", - "--timeout=10s", - "--c=20", - "--until-direct=true", - fmt.Sprintf("%s.%s.headscale.net", peername, namespace), - } - - fmt.Printf( - "Pinging using Hostname (magicdns) from %s (%s) to %s (%s)\n", - hostname, - ips[hostname], - peername, - ip, - ) - result, err := ExecuteCommand( - &tailscale, - command, - []string{}, - ) - assert.Nil(t, err) - fmt.Printf("Result for %s: %s\n", hostname, result) - assert.Contains(t, result, "pong") + command := []string{ + "tailscale", "ping", + "--timeout=10s", + "--c=20", + "--until-direct=true", + fmt.Sprintf("%s.%s.headscale.net", peername, namespace), } + + fmt.Printf( + "Pinging using Hostname from %s to %s\n", + hostname, + peername, + ) + result, err := ExecuteCommand( + &tailscale, + command, + []string{}, + ) + assert.Nil(t, err) + fmt.Printf("Result for %s: %s\n", hostname, result) + assert.Contains(t, result, "pong") }) } } } } -func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) { - ips := make(map[string]netaddr.IP) +// TODO: +// * With manual testing, MagicDNS does not respond to AAAA queries. Why? +// * Tailscaled only adds a route to the IPv4 (100.100.100.100) address of the MagicDNS service, +// event though there is an IPv6 one (fd7a:115c:a1e0::53) as well. +func (s *IntegrationTestSuite) TestMagicDNSv4() { + for namespace, scales := range s.namespaces { + ips, err := getIPs(scales.tailscales) + assert.Nil(s.T(), err) + for hostname, tailscale := range scales.tailscales { + for peername, ips := range ips { + if peername == hostname { + continue + } + s.T().Run(fmt.Sprintf("%s-%s-ipv4", hostname, peername), func(t *testing.T) { + command := []string{ + "host", "-4", "-t", "A", + fmt.Sprintf("%s.%s.headscale.net", peername, namespace), + "100.100.100.100", + } + + fmt.Printf( + "Resolving name %s (IPv4) from %s over IPv4\n", + peername, + hostname, + ) + result, err := ExecuteCommand( + &tailscale, + command, + []string{}, + ) + assert.Nil(t, err) + fmt.Printf("Result for %s: %s\n", hostname, result) + + resolved := false + for _, ip := range ips { + if strings.Contains(result, fmt.Sprintf("has address %s", ip.String())) { + resolved = true + break + } + } + assert.Equal(t, true, resolved) + }) + } + } + } +} + +func getIPs(tailscales map[string]dockertest.Resource) (map[string][]netaddr.IP, error) { + ips := make(map[string][]netaddr.IP) for hostname, tailscale := range tailscales { command := []string{"tailscale", "ip"} @@ -768,12 +820,17 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e return nil, err } - ip, err := netaddr.ParseIP(strings.TrimSuffix(result, "\n")) - if err != nil { - return nil, err + for _, address := range strings.Split(result, "\n") { + address = strings.TrimSuffix(address, "\n") + if len(address) < 1 { + continue + } + ip, err := netaddr.ParseIP(address) + if err != nil { + return nil, err + } + ips[hostname] = append(ips[hostname], ip) } - - ips[hostname] = ip } return ips, nil diff --git a/machine.go b/machine.go index c13d79d..d94fb00 100644 --- a/machine.go +++ b/machine.go @@ -1,6 +1,7 @@ package headscale import ( + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -23,6 +24,7 @@ const ( errMachineNotFound = Error("machine not found") errMachineAlreadyRegistered = Error("machine already registered") errMachineRouteIsNotAvailable = Error("route is not available on machine") + errMachineAddressesInvalid = Error("failed to parse machine addresses") ) // Machine is a Headscale client. @@ -31,7 +33,7 @@ type Machine struct { MachineKey string `gorm:"type:varchar(64);unique_index"` NodeKey string DiscoKey string - IPAddress string + IPAddresses MachineAddresses Name string NamespaceID uint Namespace Namespace `gorm:"foreignKey:NamespaceID"` @@ -64,6 +66,47 @@ func (machine Machine) isRegistered() bool { return machine.Registered } +type MachineAddresses []netaddr.IP + +func (ma MachineAddresses) ToStringSlice() []string { + strSlice := make([]string, 0, len(ma)) + for _, addr := range ma { + strSlice = append(strSlice, addr.String()) + } + + return strSlice +} + +func (ma *MachineAddresses) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + addresses := strings.Split(value, ",") + *ma = (*ma)[:0] + for _, addr := range addresses { + if len(addr) < 1 { + continue + } + parsed, err := netaddr.ParseIP(addr) + if err != nil { + return err + } + *ma = append(*ma, parsed) + } + + return nil + + default: + return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (ma MachineAddresses) Value() (driver.Value, error) { + addresses := strings.Join(ma.ToStringSlice(), ",") + + return addresses, nil +} + // isExpired returns whether the machine registration has expired. func (machine Machine) isExpired() bool { // If Expiry is not set, the client has not indicated that @@ -470,22 +513,12 @@ func (machine Machine) toNode( } addrs := []netaddr.IPPrefix{} - nodeAddr, err := netaddr.ParseIP(m.IPAddresses) - if err != nil { - log.Trace(). - Caller(). - Str("ip", machine.IPAddresses). - Msgf("Failed to parse machine IP: %s", machine.IPAddresses) - return nil, err + for _, machineAddress := range machine.IPAddresses { + ip := netaddr.IPPrefixFrom(machineAddress, machineAddress.BitLen()) + addrs = append(addrs, ip) } - ip := netaddr.IPPrefixFrom(nodeAddr, nodeAddr.BitLen()) - addrs = append(addrs, ip) - allowedIPs := []netaddr.IPPrefix{} - allowedIPs = append( - allowedIPs, - ip, - ) // we append the node own IP, as it is required by the clients + allowedIPs := append([]netaddr.IPPrefix{}, addrs...) // we append the node own IP, as it is required by the clients if includeRoutes { routesStr := []string{} @@ -592,11 +625,11 @@ func (machine *Machine) toProto() *v1.Machine { Id: machine.ID, MachineKey: machine.MachineKey, - NodeKey: machine.NodeKey, - DiscoKey: machine.DiscoKey, - IpAddress: machine.IPAddress, - Name: machine.Name, - Namespace: machine.Namespace.toProto(), + NodeKey: machine.NodeKey, + DiscoKey: machine.DiscoKey, + IpAddresses: machine.IPAddresses.ToStringSlice(), + Name: machine.Name, + Namespace: machine.Namespace.toProto(), Registered: machine.Registered, @@ -695,7 +728,7 @@ func (h *Headscale) RegisterMachine( return nil, err } - ip, err := h.getAvailableIP() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). Caller(). @@ -709,10 +742,10 @@ func (h *Headscale) RegisterMachine( log.Trace(). Caller(). Str("machine", machine.Name). - Str("ip", ip.String()). + Str("ip", strings.Join(ips.ToStringSlice(), ",")). Msg("Found IP for host") - machine.IPAddress = ip.String() + machine.IPAddresses = ips machine.NamespaceID = namespace.ID machine.Registered = true machine.RegisterMethod = RegisterMethodCLI @@ -722,7 +755,7 @@ func (h *Headscale) RegisterMachine( log.Trace(). Caller(). Str("machine", machine.Name). - Str("ip", ip.String()). + Str("ip", strings.Join(ips.ToStringSlice(), ",")). Msg("Machine registered with the database") return machine, nil diff --git a/machine_test.go b/machine_test.go index eb09007..ecb50de 100644 --- a/machine_test.go +++ b/machine_test.go @@ -6,6 +6,7 @@ import ( "time" "gopkg.in/check.v1" + "inet.af/netaddr" ) func (s *Suite) TestGetMachine(c *check.C) { @@ -199,3 +200,22 @@ func (s *Suite) TestExpireMachine(c *check.C) { c.Assert(machineFromDB.isExpired(), check.Equals, true) } + +func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { + input := MachineAddresses([]netaddr.IP{ + netaddr.MustParseIP("192.0.2.1"), + netaddr.MustParseIP("2001:db8::1"), + }) + serialized, err := input.Value() + c.Assert(err, check.IsNil) + c.Assert(serialized.(string), check.Equals, "192.0.2.1,2001:db8::1") + + var deserialized MachineAddresses + err = deserialized.Scan(serialized) + c.Assert(err, check.IsNil) + + c.Assert(len(deserialized), check.Equals, len(input)) + for i := range deserialized { + c.Assert(deserialized[i], check.Equals, input[i]) + } +} diff --git a/namespaces_test.go b/namespaces_test.go index 9793e60..d07deb9 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -4,6 +4,7 @@ import ( "github.com/rs/zerolog/log" "gopkg.in/check.v1" "gorm.io/gorm" + "inet.af/netaddr" ) func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) { @@ -146,7 +147,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.1", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyShared1.ID), } app.db.Save(machineInShared1) @@ -164,7 +165,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Namespace: *namespaceShared2, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.2", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyShared2.ID), } app.db.Save(machineInShared2) @@ -182,7 +183,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Namespace: *namespaceShared3, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.3", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyShared3.ID), } app.db.Save(machineInShared3) @@ -200,7 +201,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Namespace: *namespaceShared1, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.4", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(preAuthKey2Shared1.ID), } app.db.Save(machine2InShared1) diff --git a/oidc.go b/oidc.go index 120a4cf..a47863f 100644 --- a/oidc.go +++ b/oidc.go @@ -126,6 +126,7 @@ var oidcCallbackTemplate = template.Must( `), ) +// TODO: Why is the entire machine registration logic duplicated here? // OIDCCallback handles the callback from the OIDC endpoint // Retrieves the mkey from the state cache and adds the machine to the users email namespace // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities @@ -316,7 +317,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - ip, err := h.getAvailableIP() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). Caller(). @@ -330,7 +331,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - machine.IPAddress = ip.String() + machine.IPAddresses = ips machine.NamespaceID = namespace.ID machine.Registered = true machine.RegisterMethod = RegisterMethodOIDC diff --git a/sharing_test.go b/sharing_test.go index fd7634d..b7fef4e 100644 --- a/sharing_test.go +++ b/sharing_test.go @@ -2,6 +2,7 @@ package headscale import ( "gopkg.in/check.v1" + "inet.af/netaddr" ) func CreateNodeNamespace( @@ -26,7 +27,7 @@ func CreateNodeNamespace( NamespaceID: namespace.ID, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: ip, + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(pak1.ID), } app.db.Save(machine) @@ -214,7 +215,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { NamespaceID: namespace1.ID, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.4", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(pak4.ID), } app.db.Save(machine4) @@ -294,7 +295,7 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) { NamespaceID: namespace1.ID, Registered: true, RegisterMethod: RegisterMethodAuthKey, - IPAddress: "100.64.0.4", + IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(pak4n1.ID), } app.db.Save(machine4) diff --git a/utils.go b/utils.go index d2b6b90..3b18a40 100644 --- a/utils.go +++ b/utils.go @@ -133,9 +133,24 @@ func encode( return privKey.SealTo(*pubKey, b), nil } -func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { - ipPrefix := h.cfg.IPPrefix +func (h *Headscale) getAvailableIPs() (ips MachineAddresses, err error) { + ipPrefixes := h.cfg.IPPrefixes + for _, ipPrefix := range ipPrefixes { + var ip *netaddr.IP + ip, err = h.getAvailableIP(ipPrefix) + if err != nil { + return + } + ips = append(ips, *ip) + } + return +} + +// TODO: Is this concurrency safe? +// What would happen if multiple hosts were to register at the same time? +// Would we attempt to assign the same addresses to multiple nodes? +func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) { usedIps, err := h.getUsedIPs() if err != nil { return nil, err @@ -143,6 +158,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { ipPrefixNetworkAddress, ipPrefixBroadcastAddress := func() (netaddr.IP, netaddr.IP) { ipRange := ipPrefix.Range() + return ipRange.From(), ipRange.To() }() @@ -171,19 +187,20 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { } func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { - var addresses []string - h.db.Model(&Machine{}).Pluck("ip_address", &addresses) + // FIXME: This really deserves a better data model, + // but this was quick to get running and it should be enough + // to begin experimenting with a dual stack tailnet. + var addressesSlices []string + h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - ips := make([]netaddr.IP, len(addresses)) - for index, addr := range addresses { - if addr != "" { - ip, err := netaddr.ParseIP(addr) - if err != nil { - return nil, fmt.Errorf("failed to parse ip from database: %w", err) - } - - ips[index] = ip + ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices)) + for _, slice := range addressesSlices { + var a MachineAddresses + err := a.Scan(slice) + if err != nil { + return nil, fmt.Errorf("failed to read ip from database: %w", err) } + ips = append(ips, a...) } return ips, nil diff --git a/utils_test.go b/utils_test.go index 9b0295b..feb44d5 100644 --- a/utils_test.go +++ b/utils_test.go @@ -6,17 +6,18 @@ import ( ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ip, err := app.getAvailableIP() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") - c.Assert(ip.String(), check.Equals, expected.String()) + c.Assert(len(ips), check.Equals, 1) + c.Assert(ips[0].String(), check.Equals, expected.String()) } func (s *Suite) TestGetUsedIps(c *check.C) { - ip, err := app.getAvailableIP() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) namespace, err := app.CreateNamespace("test_ip") @@ -38,22 +39,24 @@ func (s *Suite) TestGetUsedIps(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - IPAddress: ip.String(), + IPAddresses: ips, } app.db.Save(&machine) - ips, err := app.getUsedIPs() + usedIps, err := app.getUsedIPs() c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") - c.Assert(ips[0], check.Equals, expected) + c.Assert(len(usedIps), check.Equals, 1) + c.Assert(usedIps[0], check.Equals, expected) machine1, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) - c.Assert(machine1.IPAddress, check.Equals, expected.String()) + c.Assert(len(machine1.IPAddresses), check.Equals, 1) + c.Assert(machine1.IPAddresses[0], check.Equals, expected) } func (s *Suite) TestGetMultiIp(c *check.C) { @@ -61,7 +64,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { - ip, err := app.getAvailableIP() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) @@ -80,59 +83,64 @@ func (s *Suite) TestGetMultiIp(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - IPAddress: ip.String(), + IPAddresses: ips, } app.db.Save(&machine) } - ips, err := app.getUsedIPs() + usedIps, err := app.getUsedIPs() c.Assert(err, check.IsNil) - c.Assert(len(ips), check.Equals, 350) + c.Assert(len(usedIps), check.Equals, 350) - c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) - c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) - c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) + c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) + c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) + c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) // Check that we can read back the IPs machine1, err := app.GetMachineByID(1) c.Assert(err, check.IsNil) + c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert( - machine1.IPAddress, + machine1.IPAddresses[0], check.Equals, - netaddr.MustParseIP("10.27.0.1").String(), + netaddr.MustParseIP("10.27.0.1"), ) machine50, err := app.GetMachineByID(50) c.Assert(err, check.IsNil) + c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert( - machine50.IPAddress, + machine50.IPAddresses[0], check.Equals, - netaddr.MustParseIP("10.27.0.50").String(), + netaddr.MustParseIP("10.27.0.50"), ) expectedNextIP := netaddr.MustParseIP("10.27.1.95") - nextIP, err := app.getAvailableIP() + nextIP, err := app.getAvailableIPs() c.Assert(err, check.IsNil) - c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) + c.Assert(len(nextIP), check.Equals, 1) + c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String()) // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := app.getAvailableIP() + nextIP2, err := app.getAvailableIPs() c.Assert(err, check.IsNil) - c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) + c.Assert(len(nextIP2), check.Equals, 1) + c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ip, err := app.getAvailableIP() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") - c.Assert(ip.String(), check.Equals, expected.String()) + c.Assert(len(ips), check.Equals, 1) + c.Assert(ips[0].String(), check.Equals, expected.String()) namespace, err := app.CreateNamespace("test_ip") c.Assert(err, check.IsNil) @@ -156,8 +164,9 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { } app.db.Save(&machine) - ip2, err := app.getAvailableIP() + ips2, err := app.getAvailableIPs() c.Assert(err, check.IsNil) - c.Assert(ip2.String(), check.Equals, expected.String()) + c.Assert(len(ips2), check.Equals, 1) + c.Assert(ips2[0].String(), check.Equals, expected.String()) }