Lock allocation of IP address

current logic is not safe as it will allow an IP that isnt persisted to
the DB to be given out multiple times if machines joins in quick
succession.

This adds a lock around the "get ip" and machine registration and save
to DB so we ensure thiis isnt happning.

Currently this had to be done three places, which is silly, and outlined
in #294.
This commit is contained in:
Kristoffer Dalby 2022-02-24 13:18:18 +00:00
parent 189e883f91
commit eda0a9f88a
5 changed files with 37 additions and 21 deletions

5
api.go
View file

@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey(
Str("func", "handleAuthKey").
Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
h.ipAllocationMutex.Lock()
ips, err := h.getAvailableIPs()
if err != nil {
log.Error().
@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey(
machine.Registered = true
machine.RegisterMethod = RegisterMethodAuthKey
h.db.Save(&machine)
h.ipAllocationMutex.Unlock()
}
pak.Used = true

2
app.go
View file

@ -153,6 +153,8 @@ type Headscale struct {
oidcStateCache *cache.Cache
requestedExpiryCache *cache.Cache
ipAllocationMutex sync.Mutex
}
// Look up the TLS constant relative to user-supplied TLS client

View file

@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine(
return nil, err
}
h.ipAllocationMutex.Lock()
defer h.ipAllocationMutex.Unlock()
ips, err := h.getAvailableIPs()
if err != nil {
log.Error().

View file

@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
h.ipAllocationMutex.Lock()
ips, err := h.getAvailableIPs()
if err != nil {
log.Error().
@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machine.LastSuccessfulUpdate = &now
machine.Expiry = &requestedTime
h.db.Save(&machine)
h.ipAllocationMutex.Unlock()
}
var content bytes.Buffer

View file

@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"net"
"sort"
"strings"
"github.com/rs/zerolog/log"
@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) {
return
}
// TODO: Is this concurrency safe?
// What would happen if multiple hosts were to register at the same time?
// Would we attempt to assign the same addresses to multiple nodes?
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
usedIps, err := h.getUsedIPs()
if err != nil {
@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
switch {
case ip.Compare(ipPrefixBroadcastAddress) == 0:
fallthrough
case containsIPs(usedIps, ip):
case usedIps.Contains(ip):
fallthrough
case ip.IsZero() || ip.IsLoopback():
ip = ip.Next()
@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
}
}
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
log.Trace().
Strs("addresses", addressesSlices).
Msg("Got allocated ip addresses from databases")
var ips netaddr.IPSetBuilder
for _, slice := range addressesSlices {
var a MachineAddresses
err := a.Scan(slice)
var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return nil, fmt.Errorf("failed to read ip from database: %w", err)
}
ips = append(ips, a...)
return netaddr.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
}
return ips, nil
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
log.Trace().
Interface("addresses", ips).
Msg("Parsed ip addresses that has been allocated from databases")
return netaddr.IPSet{}, nil
}
func containsString(ss []string, s string) bool {
@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool {
return false
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
}
return false
}
func tailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes))