diff --git a/.github/workflows/test-integration.yml b/.github/workflows/test-integration.yml
index d9c52c7..b8fd85a 100644
--- a/.github/workflows/test-integration.yml
+++ b/.github/workflows/test-integration.yml
@@ -29,4 +29,4 @@ jobs:
 
       - name: Run Integration tests
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: go test -tags integration -timeout 30m
+        run: make test_integration
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d48de9a..26b575b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@
 - Boundaries between Namespaces has been removed and all nodes can communicate by default [#357](https://github.com/juanfont/headscale/pull/357)
   - To limit access between nodes, use [ACLs](./docs/acls.md).
 
+**Changes**:
+
+- Fix a bug where the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346)
+
 **0.14.0 (2022-02-24):**
 
 **UPCOMING BREAKING**:
diff --git a/Makefile b/Makefile
index 5214509..266dadb 100644
--- a/Makefile
+++ b/Makefile
@@ -18,7 +18,7 @@ test:
 	@go test -coverprofile=coverage.out ./...
 
 test_integration:
-	go test -tags integration -timeout 30m -count=1 ./...
+	go test -failfast -tags integration -timeout 30m -count=1 ./...
 
 test_integration_cli:
 	go test -tags integration -v integration_cli_test.go integration_common_test.go
diff --git a/api.go b/api.go
index 073be5e..bb5495a 100644
--- a/api.go
+++ b/api.go
@@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey(
 		Str("func", "handleAuthKey").
 		Str("machine", machine.Name).
 		Msg("Authentication key was valid, proceeding to acquire IP addresses")
+
+	h.ipAllocationMutex.Lock()
+
 	ips, err := h.getAvailableIPs()
 	if err != nil {
 		log.Error().
@@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey(
 		machine.Registered = true
 		machine.RegisterMethod = RegisterMethodAuthKey
 		h.db.Save(&machine)
+
+		h.ipAllocationMutex.Unlock()
 	}
 
 	pak.Used = true
diff --git a/app.go b/app.go
index 26ec956..68d933c 100644
--- a/app.go
+++ b/app.go
@@ -153,6 +153,8 @@ type Headscale struct {
 
 	oidcStateCache       *cache.Cache
 	requestedExpiryCache *cache.Cache
+
+	ipAllocationMutex sync.Mutex
 }
 
 // Look up the TLS constant relative to user-supplied TLS client
diff --git a/integration_test.go b/integration_test.go
index 5024cd0..03d6d2f 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -15,6 +15,7 @@ import (
 	"os"
 	"path"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -44,6 +45,8 @@ type IntegrationTestSuite struct {
 
 	headscale  dockertest.Resource
 	namespaces map[string]TestNamespace
+
+	joinWaitGroup sync.WaitGroup
 }
 
 func TestIntegrationTestSuite(t *testing.T) {
@@ -118,7 +121,7 @@ func (s *IntegrationTestSuite) saveLog(
 		return err
 	}
 
-	fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
+	log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
 
 	err = ioutil.WriteFile(
 		path.Join(basePath, resource.Container.Name+".stdout.log"),
@@ -141,6 +144,34 @@ func (s *IntegrationTestSuite) saveLog(
 	return nil
 }
 
+func (s *IntegrationTestSuite) Join(
+	endpoint, key, hostname string,
+	tailscale dockertest.Resource,
+) {
+	defer s.joinWaitGroup.Done()
+
+	command := []string{
+		"tailscale",
+		"up",
+		"-login-server",
+		endpoint,
+		"--authkey",
+		key,
+		"--hostname",
+		hostname,
+	}
+
+	log.Println("Join command:", command)
+	log.Printf("Running join command for %s\n", hostname)
+	_, err := ExecuteCommand(
+		&tailscale,
+		command,
+		[]string{},
+	)
+	assert.Nil(s.T(), err)
+	log.Printf("%s joined\n", hostname)
+}
+
 func (s *IntegrationTestSuite) tailscaleContainer(
 	namespace, identifier, version string,
 ) (string, *dockertest.Resource) {
@@ -178,7 +209,7 @@ func (s *IntegrationTestSuite) tailscaleContainer(
 	if err != nil {
 		log.Fatalf("Could not start resource: %s", err)
 	}
-	fmt.Printf("Created %s container\n", hostname)
+	log.Printf("Created %s container\n", hostname)
 
 	return hostname, pts
 }
@@ -221,15 +252,15 @@ func (s *IntegrationTestSuite) SetupSuite() {
 		Cmd: []string{"headscale", "serve"},
 	}
 
-	fmt.Println("Creating headscale container")
+	log.Println("Creating headscale container")
 	if pheadscale, err := s.pool.BuildAndRunWithBuildOptions(headscaleBuildOptions, headscaleOptions, DockerRestartPolicy); err == nil {
 		s.headscale = *pheadscale
 	} else {
 		log.Fatalf("Could not start resource: %s", err)
 	}
-	fmt.Println("Created headscale container")
+	log.Println("Created headscale container")
 
-	fmt.Println("Creating tailscale containers")
+	log.Println("Creating tailscale containers")
 	for namespace, scales := range s.namespaces {
 		for i := 0; i < scales.count; i++ {
 			version := tailscaleVersions[i%len(tailscaleVersions)]
@@ -243,7 +274,7 @@ func (s *IntegrationTestSuite) SetupSuite() {
 		}
 	}
 
-	fmt.Println("Waiting for headscale to be ready")
+	log.Println("Waiting for headscale to be ready")
 	hostEndpoint := fmt.Sprintf("localhost:%s", s.headscale.GetPort("8080/tcp"))
 
 	if err := s.pool.Retry(func() error {
@@ -266,19 +297,19 @@ func (s *IntegrationTestSuite) SetupSuite() {
 		// https://github.com/stretchr/testify/issues/849
 		return // fmt.Errorf("Could not connect to headscale: %s", err)
 	}
-	fmt.Println("headscale container is ready")
+	log.Println("headscale container is ready")
 
 	for namespace, scales := range s.namespaces {
-		fmt.Printf("Creating headscale namespace: %s\n", namespace)
+		log.Printf("Creating headscale namespace: %s\n", namespace)
 		result, err := ExecuteCommand(
 			&s.headscale,
 			[]string{"headscale", "namespaces", "create", namespace},
 			[]string{},
 		)
-		fmt.Println("headscale create namespace result: ", result)
+		log.Println("headscale create namespace result: ", result)
 		assert.Nil(s.T(), err)
 
-		fmt.Printf("Creating pre auth key for %s\n", namespace)
+		log.Printf("Creating pre auth key for %s\n", namespace)
 		preAuthResult, err := ExecuteCommand(
 			&s.headscale,
 			[]string{
@@ -304,33 +335,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
 
 		headscaleEndpoint := "http://headscale:8080"
 
-		fmt.Printf(
+		log.Printf(
 			"Joining tailscale containers to headscale at %s\n",
 			headscaleEndpoint,
 		)
 		for hostname, tailscale := range scales.tailscales {
-			command := []string{
-				"tailscale",
-				"up",
-				"-login-server",
-				headscaleEndpoint,
-				"--authkey",
-				preAuthKey.Key,
-				"--hostname",
-				hostname,
-			}
-
-			fmt.Println("Join command:", command)
-			fmt.Printf("Running join command for %s\n", hostname)
-			result, err := ExecuteCommand(
-				&tailscale,
-				command,
-				[]string{},
-			)
-			fmt.Println("tailscale result: ", result)
-			assert.Nil(s.T(), err)
-			fmt.Printf("%s joined\n", hostname)
+			s.joinWaitGroup.Add(1)
+			go s.Join(headscaleEndpoint, preAuthKey.Key, hostname, tailscale)
 		}
+
+		s.joinWaitGroup.Wait()
 	}
 
 	// The nodes need a bit of time to get their updated maps from headscale
@@ -350,7 +364,7 @@ func (s *IntegrationTestSuite) HandleStats(
 
 func (s *IntegrationTestSuite) TestListNodes() {
 	for namespace, scales := range s.namespaces {
-		fmt.Println("Listing nodes")
+		log.Println("Listing nodes")
 		result, err := ExecuteCommand(
 			&s.headscale,
 			[]string{"headscale", "--namespace", namespace, "nodes", "list"},
@@ -358,7 +372,7 @@ func (s *IntegrationTestSuite) TestListNodes() {
 		)
 		assert.Nil(s.T(), err)
 
-		fmt.Printf("List nodes: \n%s\n", result)
+		log.Printf("List nodes: \n%s\n", result)
 
 		// Chck that the correct count of host is present in node list
 		lines := strings.Split(result, "\n")
@@ -381,7 +395,7 @@ func (s *IntegrationTestSuite) TestGetIpAddresses() {
 			s.T().Run(hostname, func(t *testing.T) {
 				assert.NotNil(t, ip)
 
-				fmt.Printf("IP for %s: %s\n", hostname, ip)
+				log.Printf("IP for %s: %s\n", hostname, ip)
 
 				// c.Assert(ip.Valid(), check.IsTrue)
 				assert.True(t, ip.Is4() || ip.Is6())
@@ -410,7 +424,7 @@ func (s *IntegrationTestSuite) TestGetIpAddresses() {
 // 		s.T().Run(hostname, func(t *testing.T) {
 // 			command := []string{"tailscale", "status", "--json"}
 //
-// 			fmt.Printf("Getting status for %s\n", hostname)
+// 			log.Printf("Getting status for %s\n", hostname)
 // 			result, err := ExecuteCommand(
 // 				&tailscale,
 // 				command,
@@ -477,7 +491,7 @@ func (s *IntegrationTestSuite) TestPingAllPeersByAddress() {
 					ip.String(),
 				}
 
-				fmt.Printf(
+				log.Printf(
 					"Pinging from %s to %s (%s)\n",
 					hostname,
 					peername,
@@ -489,7 +503,7 @@ func (s *IntegrationTestSuite) TestPingAllPeersByAddress() {
 					[]string{},
 				)
 				assert.Nil(t, err)
-				fmt.Printf("Result for %s: %s\n", hostname, result)
+				log.Printf("Result for %s: %s\n", hostname, result)
 				assert.Contains(t, result, "pong")
 			})
 		}
@@ -512,6 +526,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 			}
 			time.Sleep(sleepInverval)
 		}
+
 		return
 	}
 
@@ -534,7 +549,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 					fmt.Sprintf("%s:", peername),
 				}
 				retry(10, 1*time.Second, func() error {
-					fmt.Printf(
+					log.Printf(
 						"Sending file from %s to %s\n",
 						hostname,
 						peername,
@@ -573,7 +588,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 					"ls",
 					fmt.Sprintf("/tmp/file_from_%s", peername),
 				}
-				fmt.Printf(
+				log.Printf(
 					"Checking file in %s (%s) from %s (%s)\n",
 					hostname,
 					ips[hostname],
@@ -586,7 +601,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 					[]string{},
 				)
 				assert.Nil(t, err)
-				fmt.Printf("Result for %s: %s\n", peername, result)
+				log.Printf("Result for %s: %s\n", peername, result)
 				assert.Equal(
 					t,
 					fmt.Sprintf("/tmp/file_from_%s\n", peername),
@@ -616,7 +631,7 @@ func (s *IntegrationTestSuite) TestPingAllPeersByHostname() {
 					fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
 				}
 
-				fmt.Printf(
+				log.Printf(
 					"Pinging using hostname from %s to %s\n",
 					hostname,
 					peername,
@@ -627,7 +642,7 @@ func (s *IntegrationTestSuite) TestPingAllPeersByHostname() {
 					[]string{},
 				)
 				assert.Nil(t, err)
-				fmt.Printf("Result for %s: %s\n", hostname, result)
+				log.Printf("Result for %s: %s\n", hostname, result)
 				assert.Contains(t, result, "pong")
 			})
 		}
@@ -650,7 +665,7 @@ func (s *IntegrationTestSuite) TestMagicDNS() {
 					fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
 				}
 
-				fmt.Printf(
+				log.Printf(
 					"Resolving name %s from %s\n",
 					peername,
 					hostname,
@@ -661,7 +676,7 @@ func (s *IntegrationTestSuite) TestMagicDNS() {
 					[]string{},
 				)
 				assert.Nil(t, err)
-				fmt.Printf("Result for %s: %s\n", hostname, result)
+				log.Printf("Result for %s: %s\n", hostname, result)
 
 				for _, ip := range ips {
 					assert.Contains(t, result, ip.String())
diff --git a/machine.go b/machine.go
index 603441d..892fa3d 100644
--- a/machine.go
+++ b/machine.go
@@ -742,6 +742,9 @@ func (h *Headscale) RegisterMachine(
 		return nil, err
 	}
 
+	h.ipAllocationMutex.Lock()
+	defer h.ipAllocationMutex.Unlock()
+
 	ips, err := h.getAvailableIPs()
 	if err != nil {
 		log.Error().
diff --git a/oidc.go b/oidc.go
index a47863f..cd77d29 100644
--- a/oidc.go
+++ b/oidc.go
@@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
 		return
 	}
 
+	h.ipAllocationMutex.Lock()
+
 	ips, err := h.getAvailableIPs()
 	if err != nil {
 		log.Error().
@@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
 		machine.LastSuccessfulUpdate = &now
 		machine.Expiry = &requestedTime
 		h.db.Save(&machine)
+
+		h.ipAllocationMutex.Unlock()
 	}
 
 	var content bytes.Buffer
diff --git a/utils.go b/utils.go
index 3cee5e3..004bf30 100644
--- a/utils.go
+++ b/utils.go
@@ -157,9 +157,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) {
 	return
 }
 
-// TODO: Is this concurrency safe?
-// What would happen if multiple hosts were to register at the same time?
-// Would we attempt to assign the same addresses to multiple nodes?
 func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
 	usedIps, err := h.getUsedIPs()
 	if err != nil {
@@ -179,7 +176,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
 		switch {
 		case ip.Compare(ipPrefixBroadcastAddress) == 0:
 			fallthrough
-		case containsIPs(usedIps, ip):
+		case usedIps.Contains(ip):
 			fallthrough
 		case ip.IsZero() || ip.IsLoopback():
 			ip = ip.Next()
@@ -192,24 +189,46 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
 	}
 }
 
-func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
+func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
 	// FIXME: This really deserves a better data model,
 	// but this was quick to get running and it should be enough
 	// to begin experimenting with a dual stack tailnet.
 	var addressesSlices []string
 	h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
 
-	ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
+	log.Trace().
+		Strs("addresses", addressesSlices).
+		Msg("Got allocated ip addresses from databases")
+
+	var ips netaddr.IPSetBuilder
 	for _, slice := range addressesSlices {
-		var a MachineAddresses
-		err := a.Scan(slice)
+		var machineAddresses MachineAddresses
+		err := machineAddresses.Scan(slice)
 		if err != nil {
-			return nil, fmt.Errorf("failed to read ip from database: %w", err)
+			return &netaddr.IPSet{}, fmt.Errorf(
+				"failed to read ip from database: %w",
+				err,
+			)
+		}
+
+		for _, ip := range machineAddresses {
+			ips.Add(ip)
 		}
-		ips = append(ips, a...)
 	}
 
-	return ips, nil
+	log.Trace().
+		Interface("addresses", ips).
+ Msg("Parsed ip addresses that has been allocated from databases") + + ipSet, err := ips.IPSet() + if err != nil { + return &netaddr.IPSet{}, fmt.Errorf( + "failed to build IP Set: %w", + err, + ) + } + + return ipSet, nil } func containsString(ss []string, s string) bool { @@ -222,16 +241,6 @@ func containsString(ss []string, s string) bool { return false } -func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { - for _, v := range ips { - if v == ip { - return true - } - } - - return false -} - func tailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes)) diff --git a/utils_test.go b/utils_test.go index feb44d5..896040c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -48,9 +48,12 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") + expectedIPSetBuilder := netaddr.IPSetBuilder{} + expectedIPSetBuilder.Add(expected) + expectedIPSet, _ := expectedIPSetBuilder.IPSet() - c.Assert(len(usedIps), check.Equals, 1) - c.Assert(usedIps[0], check.Equals, expected) + c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) + c.Assert(usedIps.Contains(expected), check.Equals, true) machine1, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) @@ -64,6 +67,8 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { + app.ipAllocationMutex.Lock() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) @@ -86,17 +91,30 @@ func (s *Suite) TestGetMultiIp(c *check.C) { IPAddresses: ips, } app.db.Save(&machine) + + app.ipAllocationMutex.Unlock() } usedIps, err := app.getUsedIPs() - c.Assert(err, check.IsNil) - c.Assert(len(usedIps), check.Equals, 350) + expected0 := netaddr.MustParseIP("10.27.0.1") + expected9 := netaddr.MustParseIP("10.27.0.10") + expected300 := netaddr.MustParseIP("10.27.0.45") - c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) - c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) - c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) + notExpectedIPSetBuilder := netaddr.IPSetBuilder{} + notExpectedIPSetBuilder.Add(expected0) + notExpectedIPSetBuilder.Add(expected9) + notExpectedIPSetBuilder.Add(expected300) + notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet() + c.Assert(err, check.IsNil) + + // We actually expect it to be a lot larger + c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false) + + c.Assert(usedIps.Contains(expected0), check.Equals, true) + c.Assert(usedIps.Contains(expected9), check.Equals, true) + c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs machine1, err := app.GetMachineByID(1)