From eda0a9f88a694c62afc858202d144e0a62019cf7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 24 Feb 2022 13:18:18 +0000 Subject: [PATCH] Lock allocation of IP address current logic is not safe as it will allow an IP that isnt persisted to the DB to be given out multiple times if machines joins in quick succession. This adds a lock around the "get ip" and machine registration and save to DB so we ensure thiis isnt happning. Currently this had to be done three places, which is silly, and outlined in #294. --- api.go | 5 +++++ app.go | 2 ++ machine.go | 3 +++ oidc.go | 4 ++++ utils.go | 44 +++++++++++++++++++++++--------------------- 5 files changed, 37 insertions(+), 21 deletions(-) diff --git a/api.go b/api.go index 073be5e..bb5495a 100644 --- a/api.go +++ b/api.go @@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey( Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Authentication key was valid, proceeding to acquire IP addresses") + + h.ipAllocationMutex.Lock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). @@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey( machine.Registered = true machine.RegisterMethod = RegisterMethodAuthKey h.db.Save(&machine) + + h.ipAllocationMutex.Unlock() } pak.Used = true diff --git a/app.go b/app.go index 26ec956..68d933c 100644 --- a/app.go +++ b/app.go @@ -153,6 +153,8 @@ type Headscale struct { oidcStateCache *cache.Cache requestedExpiryCache *cache.Cache + + ipAllocationMutex sync.Mutex } // Look up the TLS constant relative to user-supplied TLS client diff --git a/machine.go b/machine.go index 3c704ad..7de99a6 100644 --- a/machine.go +++ b/machine.go @@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine( return nil, err } + h.ipAllocationMutex.Lock() + defer h.ipAllocationMutex.Unlock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). diff --git a/oidc.go b/oidc.go index a47863f..cd77d29 100644 --- a/oidc.go +++ b/oidc.go @@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + h.ipAllocationMutex.Lock() + ips, err := h.getAvailableIPs() if err != nil { log.Error(). @@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machine.LastSuccessfulUpdate = &now machine.Expiry = &requestedTime h.db.Save(&machine) + + h.ipAllocationMutex.Unlock() } var content bytes.Buffer diff --git a/utils.go b/utils.go index 3cee5e3..c1a39bb 100644 --- a/utils.go +++ b/utils.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "net" + "sort" "strings" "github.com/rs/zerolog/log" @@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) { return } -// TODO: Is this concurrency safe? -// What would happen if multiple hosts were to register at the same time? -// Would we attempt to assign the same addresses to multiple nodes? func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) { usedIps, err := h.getUsedIPs() if err != nil { @@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro switch { case ip.Compare(ipPrefixBroadcastAddress) == 0: fallthrough - case containsIPs(usedIps, ip): + case usedIps.Contains(ip): fallthrough case ip.IsZero() || ip.IsLoopback(): ip = ip.Next() @@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro } } -func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { +func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices)) + log.Trace(). + Strs("addresses", addressesSlices). + Msg("Got allocated ip addresses from databases") + + var ips netaddr.IPSetBuilder for _, slice := range addressesSlices { - var a MachineAddresses - err := a.Scan(slice) + var machineAddresses MachineAddresses + err := machineAddresses.Scan(slice) if err != nil { - return nil, fmt.Errorf("failed to read ip from database: %w", err) + return netaddr.IPSet{}, fmt.Errorf( + "failed to read ip from database: %w", + err, + ) + } + + for _, ip := range machineAddresses { + ips.Add(ip) } - ips = append(ips, a...) } - return ips, nil + log.Trace(). + Interface("addresses", ips). + Msg("Parsed ip addresses that has been allocated from databases") + + return netaddr.IPSet{}, nil } func containsString(ss []string, s string) bool { @@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool { return false } -func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { - for _, v := range ips { - if v == ip { - return true - } - } - - return false -} - func tailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes))