Lock allocation of IP address
current logic is not safe as it will allow an IP that isnt persisted to the DB to be given out multiple times if machines joins in quick succession. This adds a lock around the "get ip" and machine registration and save to DB so we ensure thiis isnt happning. Currently this had to be done three places, which is silly, and outlined in #294.
This commit is contained in:
parent
189e883f91
commit
eda0a9f88a
5 changed files with 37 additions and 21 deletions
5
api.go
5
api.go
|
@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey(
|
|||
Str("func", "handleAuthKey").
|
||||
Str("machine", machine.Name).
|
||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||
|
||||
h.ipAllocationMutex.Lock()
|
||||
|
||||
ips, err := h.getAvailableIPs()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey(
|
|||
machine.Registered = true
|
||||
machine.RegisterMethod = RegisterMethodAuthKey
|
||||
h.db.Save(&machine)
|
||||
|
||||
h.ipAllocationMutex.Unlock()
|
||||
}
|
||||
|
||||
pak.Used = true
|
||||
|
|
2
app.go
2
app.go
|
@ -153,6 +153,8 @@ type Headscale struct {
|
|||
oidcStateCache *cache.Cache
|
||||
|
||||
requestedExpiryCache *cache.Cache
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Look up the TLS constant relative to user-supplied TLS client
|
||||
|
|
|
@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
h.ipAllocationMutex.Lock()
|
||||
defer h.ipAllocationMutex.Unlock()
|
||||
|
||||
ips, err := h.getAvailableIPs()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
|
4
oidc.go
4
oidc.go
|
@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
h.ipAllocationMutex.Lock()
|
||||
|
||||
ips, err := h.getAvailableIPs()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &requestedTime
|
||||
h.db.Save(&machine)
|
||||
|
||||
h.ipAllocationMutex.Unlock()
|
||||
}
|
||||
|
||||
var content bytes.Buffer
|
||||
|
|
44
utils.go
44
utils.go
|
@ -12,6 +12,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO: Is this concurrency safe?
|
||||
// What would happen if multiple hosts were to register at the same time?
|
||||
// Would we attempt to assign the same addresses to multiple nodes?
|
||||
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
|
||||
usedIps, err := h.getUsedIPs()
|
||||
if err != nil {
|
||||
|
@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
|
|||
switch {
|
||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||
fallthrough
|
||||
case containsIPs(usedIps, ip):
|
||||
case usedIps.Contains(ip):
|
||||
fallthrough
|
||||
case ip.IsZero() || ip.IsLoopback():
|
||||
ip = ip.Next()
|
||||
|
@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
|
||||
func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
|
||||
// FIXME: This really deserves a better data model,
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
|
||||
log.Trace().
|
||||
Strs("addresses", addressesSlices).
|
||||
Msg("Got allocated ip addresses from databases")
|
||||
|
||||
var ips netaddr.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
var a MachineAddresses
|
||||
err := a.Scan(slice)
|
||||
var machineAddresses MachineAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read ip from database: %w", err)
|
||||
return netaddr.IPSet{}, fmt.Errorf(
|
||||
"failed to read ip from database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
ips = append(ips, a...)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
log.Trace().
|
||||
Interface("addresses", ips).
|
||||
Msg("Parsed ip addresses that has been allocated from databases")
|
||||
|
||||
return netaddr.IPSet{}, nil
|
||||
}
|
||||
|
||||
func containsString(ss []string, s string) bool {
|
||||
|
@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
|
||||
for _, v := range ips {
|
||||
if v == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func tailNodesToString(nodes []*tailcfg.Node) string {
|
||||
temp := make([]string, len(nodes))
|
||||
|
||||
|
|
Loading…
Reference in a new issue