Lock allocation of IP address
current logic is not safe as it will allow an IP that isnt persisted to the DB to be given out multiple times if machines joins in quick succession. This adds a lock around the "get ip" and machine registration and save to DB so we ensure thiis isnt happning. Currently this had to be done three places, which is silly, and outlined in #294.
This commit is contained in:
parent
189e883f91
commit
eda0a9f88a
5 changed files with 37 additions and 21 deletions
5
api.go
5
api.go
|
@ -574,6 +574,9 @@ func (h *Headscale) handleAuthKey(
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||||
|
|
||||||
|
h.ipAllocationMutex.Lock()
|
||||||
|
|
||||||
ips, err := h.getAvailableIPs()
|
ips, err := h.getAvailableIPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -602,6 +605,8 @@ func (h *Headscale) handleAuthKey(
|
||||||
machine.Registered = true
|
machine.Registered = true
|
||||||
machine.RegisterMethod = RegisterMethodAuthKey
|
machine.RegisterMethod = RegisterMethodAuthKey
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
|
h.ipAllocationMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
|
|
2
app.go
2
app.go
|
@ -153,6 +153,8 @@ type Headscale struct {
|
||||||
oidcStateCache *cache.Cache
|
oidcStateCache *cache.Cache
|
||||||
|
|
||||||
requestedExpiryCache *cache.Cache
|
requestedExpiryCache *cache.Cache
|
||||||
|
|
||||||
|
ipAllocationMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the TLS constant relative to user-supplied TLS client
|
// Look up the TLS constant relative to user-supplied TLS client
|
||||||
|
|
|
@ -856,6 +856,9 @@ func (h *Headscale) RegisterMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.ipAllocationMutex.Lock()
|
||||||
|
defer h.ipAllocationMutex.Unlock()
|
||||||
|
|
||||||
ips, err := h.getAvailableIPs()
|
ips, err := h.getAvailableIPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
4
oidc.go
4
oidc.go
|
@ -317,6 +317,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.ipAllocationMutex.Lock()
|
||||||
|
|
||||||
ips, err := h.getAvailableIPs()
|
ips, err := h.getAvailableIPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -338,6 +340,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
machine.Expiry = &requestedTime
|
machine.Expiry = &requestedTime
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
|
h.ipAllocationMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var content bytes.Buffer
|
var content bytes.Buffer
|
||||||
|
|
46
utils.go
46
utils.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -157,9 +158,6 @@ func GetIPPrefixEndpoints(na netaddr.IPPrefix) (network, broadcast netaddr.IP) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Is this concurrency safe?
|
|
||||||
// What would happen if multiple hosts were to register at the same time?
|
|
||||||
// Would we attempt to assign the same addresses to multiple nodes?
|
|
||||||
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
|
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
|
||||||
usedIps, err := h.getUsedIPs()
|
usedIps, err := h.getUsedIPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -179,7 +177,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
|
||||||
switch {
|
switch {
|
||||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||||
fallthrough
|
fallthrough
|
||||||
case containsIPs(usedIps, ip):
|
case usedIps.Contains(ip):
|
||||||
fallthrough
|
fallthrough
|
||||||
case ip.IsZero() || ip.IsLoopback():
|
case ip.IsZero() || ip.IsLoopback():
|
||||||
ip = ip.Next()
|
ip = ip.Next()
|
||||||
|
@ -192,24 +190,38 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
|
func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) {
|
||||||
// FIXME: This really deserves a better data model,
|
// FIXME: This really deserves a better data model,
|
||||||
// but this was quick to get running and it should be enough
|
// but this was quick to get running and it should be enough
|
||||||
// to begin experimenting with a dual stack tailnet.
|
// to begin experimenting with a dual stack tailnet.
|
||||||
var addressesSlices []string
|
var addressesSlices []string
|
||||||
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||||
|
|
||||||
ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
|
log.Trace().
|
||||||
|
Strs("addresses", addressesSlices).
|
||||||
|
Msg("Got allocated ip addresses from databases")
|
||||||
|
|
||||||
|
var ips netaddr.IPSetBuilder
|
||||||
for _, slice := range addressesSlices {
|
for _, slice := range addressesSlices {
|
||||||
var a MachineAddresses
|
var machineAddresses MachineAddresses
|
||||||
err := a.Scan(slice)
|
err := machineAddresses.Scan(slice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read ip from database: %w", err)
|
return netaddr.IPSet{}, fmt.Errorf(
|
||||||
}
|
"failed to read ip from database: %w",
|
||||||
ips = append(ips, a...)
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
for _, ip := range machineAddresses {
|
||||||
|
ips.Add(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Interface("addresses", ips).
|
||||||
|
Msg("Parsed ip addresses that has been allocated from databases")
|
||||||
|
|
||||||
|
return netaddr.IPSet{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func containsString(ss []string, s string) bool {
|
func containsString(ss []string, s string) bool {
|
||||||
|
@ -222,16 +234,6 @@ func containsString(ss []string, s string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
|
|
||||||
for _, v := range ips {
|
|
||||||
if v == ip {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func tailNodesToString(nodes []*tailcfg.Node) string {
|
func tailNodesToString(nodes []*tailcfg.Node) string {
|
||||||
temp := make([]string, len(nodes))
|
temp := make([]string, len(nodes))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue