Lock allocation of IP address

The current logic is not safe, as it will allow an IP that is not yet
persisted to the DB to be handed out multiple times if machines join
in quick succession.

This adds a lock around the "get ip", machine registration and save
to DB steps so we ensure this cannot happen.

Currently this has to be done in three places, which is silly, and is
outlined in #294.
Kristoffer Dalby 2022-02-24 13:18:18 +00:00
parent 189e883f91
commit eda0a9f88a
5 changed files with 37 additions and 21 deletions
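The fix is the same shape in every registration path: take a mutex before looking up which IPs are in use, and release it only after the machine and its new address have been saved, so a second join cannot race in between and be handed the same "free" address. Below is a minimal, self-contained sketch of that pattern, not headscale's actual code: the allocator type, the in-memory used map standing in for the DB, and the 100.64.0.x addresses are all illustrative.

package main

import (
    "fmt"
    "sync"
)

// allocator stands in for the Headscale struct: one mutex guards both the
// "find a free IP" lookup and the write that records the allocation.
type allocator struct {
    mu   sync.Mutex
    used map[string]bool // addresses already handed out; stands in for the DB
    next int
}

// register picks the next free address and records it under a single lock,
// so two machines joining at the same time can never get the same address.
func (a *allocator) register() string {
    a.mu.Lock()
    defer a.mu.Unlock()

    // "get ip": advance until we find an address not yet recorded.
    ip := fmt.Sprintf("100.64.0.%d", a.next)
    for a.used[ip] {
        a.next++
        ip = fmt.Sprintf("100.64.0.%d", a.next)
    }

    // "save to DB": record the allocation before the lock is released.
    a.used[ip] = true
    a.next++

    return ip
}

func main() {
    a := &allocator{used: map[string]bool{}}

    // Simulate machines joining in quick succession.
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            fmt.Println(a.register())
        }()
    }
    wg.Wait()
}

In the diff below the same pattern appears three times, in handleAuthKey, RegisterMachine and OIDCCallback, which is the duplication #294 is about.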

api.go

@@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey(
         Str("func", "handleAuthKey").
         Str("machine", machine.Name).
         Msg("Authentication key was valid, proceeding to acquire IP addresses")
+
+    h.ipAllocationMutex.Lock()
+
     ips, err := h.getAvailableIPs()
     if err != nil {
         log.Error().
@@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey(
         machine.Registered = true
         machine.RegisterMethod = RegisterMethodAuthKey
         h.db.Save(&machine)
+
+        h.ipAllocationMutex.Unlock()
     }
 
     pak.Used = true

app.go

@@ -153,6 +153,8 @@ type Headscale struct {
     oidcStateCache       *cache.Cache
     requestedExpiryCache *cache.Cache
+
+    ipAllocationMutex sync.Mutex
 }
 
 // Look up the TLS constant relative to user-supplied TLS client


@@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine(
         return nil, err
     }
 
+    h.ipAllocationMutex.Lock()
+    defer h.ipAllocationMutex.Unlock()
+
     ips, err := h.getAvailableIPs()
     if err != nil {
         log.Error().


@@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
         return
     }
 
+    h.ipAllocationMutex.Lock()
+
     ips, err := h.getAvailableIPs()
     if err != nil {
         log.Error().
@@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
         machine.LastSuccessfulUpdate = &now
         machine.Expiry = &requestedTime
         h.db.Save(&machine)
+
+        h.ipAllocationMutex.Unlock()
     }
 
     var content bytes.Buffer


@@ -12,6 +12,7 @@ import (
     "encoding/json"
     "fmt"
     "net"
+    "sort"
     "strings"
 
     "github.com/rs/zerolog/log"
@@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) {
     return
 }
 
-// TODO: Is this concurrency safe?
-// What would happen if multiple hosts were to register at the same time?
-// Would we attempt to assign the same addresses to multiple nodes?
 func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
     usedIps, err := h.getUsedIPs()
     if err != nil {
@@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
         switch {
         case ip.Compare(ipPrefixBroadcastAddress) == 0:
             fallthrough
-        case containsIPs(usedIps, ip):
+        case usedIps.Contains(ip):
             fallthrough
         case ip.IsZero() || ip.IsLoopback():
             ip = ip.Next()
@@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
     }
 }
 
-func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
+func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
     // FIXME: This really deserves a better data model,
     // but this was quick to get running and it should be enough
     // to begin experimenting with a dual stack tailnet.
     var addressesSlices []string
     h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
 
-    ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
+    log.Trace().
+        Strs("addresses", addressesSlices).
+        Msg("Got allocated ip addresses from databases")
+
+    var ips netaddr.IPSetBuilder
+
     for _, slice := range addressesSlices {
-        var a MachineAddresses
-        err := a.Scan(slice)
+        var machineAddresses MachineAddresses
+        err := machineAddresses.Scan(slice)
         if err != nil {
-            return nil, fmt.Errorf("failed to read ip from database: %w", err)
+            return netaddr.IPSet{}, fmt.Errorf(
+                "failed to read ip from database: %w",
+                err,
+            )
+        }
+
+        for _, ip := range machineAddresses {
+            ips.Add(ip)
         }
-
-        ips = append(ips, a...)
     }
 
-    return ips, nil
+    log.Trace().
+        Interface("addresses", ips).
+        Msg("Parsed ip addresses that has been allocated from databases")
+
+    return netaddr.IPSet{}, nil
 }
 
 func containsString(ss []string, s string) bool {
@@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool {
     return false
 }
 
-func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
-    for _, v := range ips {
-        if v == ip {
-            return true
-        }
-    }
-
-    return false
-}
-
 func tailNodesToString(nodes []*tailcfg.Node) string {
     temp := make([]string, len(nodes))
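
The last file also swaps the []netaddr.IP slice and the hand-rolled containsIPs helper for netaddr.IPSet, built through netaddr.IPSetBuilder. A small standalone sketch of how that inet.af/netaddr API is used, roughly mirroring what getUsedIPs and getAvailableIP now do; the addresses are made up and error handling is reduced to a panic:

package main

import (
    "fmt"

    "inet.af/netaddr"
)

func main() {
    // Collect the allocated addresses into a builder, as getUsedIPs now does.
    var builder netaddr.IPSetBuilder
    builder.Add(netaddr.MustParseIP("100.64.0.1"))
    builder.Add(netaddr.MustParseIP("100.64.0.2"))

    // IPSet() freezes the builder into an immutable set.
    used, err := builder.IPSet()
    if err != nil {
        panic(err)
    }

    // Contains replaces the removed containsIPs loop in getAvailableIP.
    fmt.Println(used.Contains(netaddr.MustParseIP("100.64.0.2"))) // true
    fmt.Println(used.Contains(netaddr.MustParseIP("100.64.0.3"))) // false
}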