diff --git a/api.go b/api.go index 073be5e..bb5495a 100644 --- a/api.go +++ b/api.go @@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey( Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Authentication key was valid, proceeding to acquire IP addresses") + + h.ipAllocationMutex.Lock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). @@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey( machine.Registered = true machine.RegisterMethod = RegisterMethodAuthKey h.db.Save(&machine) + + h.ipAllocationMutex.Unlock() } pak.Used = true diff --git a/app.go b/app.go index 26ec956..68d933c 100644 --- a/app.go +++ b/app.go @@ -153,6 +153,8 @@ type Headscale struct { oidcStateCache *cache.Cache requestedExpiryCache *cache.Cache + + ipAllocationMutex sync.Mutex } // Look up the TLS constant relative to user-supplied TLS client diff --git a/machine.go b/machine.go index 3c704ad..7de99a6 100644 --- a/machine.go +++ b/machine.go @@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine( return nil, err } + h.ipAllocationMutex.Lock() + defer h.ipAllocationMutex.Unlock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). diff --git a/oidc.go b/oidc.go index a47863f..cd77d29 100644 --- a/oidc.go +++ b/oidc.go @@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + h.ipAllocationMutex.Lock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). @@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machine.LastSuccessfulUpdate = &now machine.Expiry = &requestedTime h.db.Save(&machine) + + h.ipAllocationMutex.Unlock() } var content bytes.Buffer diff --git a/utils.go b/utils.go index 3cee5e3..c1a39bb 100644 --- a/utils.go +++ b/utils.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "net" + "sort" "strings" "github.com/rs/zerolog/log" @@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) { return } -// TODO: Is this concurrency safe? -// What would happen if multiple hosts were to register at the same time? -// Would we attempt to assign the same addresses to multiple nodes? func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) { usedIps, err := h.getUsedIPs() if err != nil { @@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro switch { case ip.Compare(ipPrefixBroadcastAddress) == 0: fallthrough - case containsIPs(usedIps, ip): + case usedIps.Contains(ip): fallthrough case ip.IsZero() || ip.IsLoopback(): ip = ip.Next() @@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro } } -func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { +func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices)) + log.Trace(). + Strs("addresses", addressesSlices). + Msg("Got allocated ip addresses from databases") + + var ips netaddr.IPSetBuilder for _, slice := range addressesSlices { - var a MachineAddresses - err := a.Scan(slice) + var machineAddresses MachineAddresses + err := machineAddresses.Scan(slice) if err != nil { - return nil, fmt.Errorf("failed to read ip from database: %w", err) + return netaddr.IPSet{}, fmt.Errorf( + "failed to read ip from database: %w", + err, + ) + } + + for _, ip := range machineAddresses { + ips.Add(ip) } - ips = append(ips, a...) } - return ips, nil + log.Trace(). + Interface("addresses", ips). + Msg("Parsed ip addresses that has been allocated from databases") + + return netaddr.IPSet{}, nil } func containsString(ss []string, s string) bool { @@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool { return false } -func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { - for _, v := range ips { - if v == ip { - return true - } - } - - return false -} - func tailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes))