Register new machines needing callback in memory

This commit stores temporary registration data in cache, instead of
memory allowing us to only have actually registered machines in the
database.
This commit is contained in:
Kristoffer Dalby 2022-02-28 08:06:39 +00:00
parent 1caa6f5d69
commit 469551bc5d
7 changed files with 136 additions and 95 deletions

111
api.go
View file

@ -125,25 +125,40 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
machine, err := h.GetMachineByMachineKey(machineKey) machine, err := h.GetMachineByMachineKey(machineKey)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{}, machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
MachineKey: MachinePublicKeyStripPrefix(machineKey),
Name: req.Hostinfo.Hostname, // If the machine has AuthKey set, handle registration via PreAuthKeys
} if req.Auth.AuthKey != "" {
if err := h.db.Create(&newMachine).Error; err != nil { h.handleAuthKey(ctx, machineKey, req)
log.Error().
Caller().
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return return
} }
machine = &newMachine
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
Expiry: &time.Time{},
MachineKey: machineKeyStr,
Name: req.Hostinfo.Hostname,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now,
} }
if machine.Registered { h.registrationCache.Set(
machineKeyStr,
newMachine,
requestedExpiryCacheExpiration,
)
h.handleMachineRegistrationNew(ctx, machineKey, req)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration // If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either: // request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
@ -180,15 +195,6 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
return return
} }
// If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, req, *machine)
return
}
h.handleMachineRegistrationNew(ctx, machineKey, req, *machine)
} }
func (h *Headscale) getMapResponse( func (h *Headscale) getMapResponse(
@ -402,7 +408,7 @@ func (h *Headscale) handleMachineExpired(
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, registerRequest, machine) h.handleAuthKey(ctx, machineKey, registerRequest)
return return
} }
@ -465,13 +471,12 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL // The machine registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
@ -487,7 +492,7 @@ func (h *Headscale) handleMachineRegistrationNew(
if !registerRequest.Expiry.IsZero() { if !registerRequest.Expiry.IsZero() {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry). Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested, adding to cache") Msg("Non-zero expiry time requested, adding to cache")
h.requestedExpiryCache.Set( h.requestedExpiryCache.Set(
@ -497,11 +502,6 @@ func (h *Headscale) handleMachineRegistrationNew(
) )
} }
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
// save the NodeKey
h.db.Save(&machine)
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -520,19 +520,21 @@ func (h *Headscale) handleAuthKey(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine,
) { ) {
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -541,46 +543,40 @@ func (h *Headscale) handleAuthKey(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
return return
} }
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
return return
} }
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, registerRequest.Expiry)
} else {
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
now := time.Now().UTC() now := time.Now().UTC()
_, err = h.RegisterMachine( machine, err := h.RegisterMachine(
machine.Name, registerRequest.Hostinfo.Hostname,
machine.Namespace.Name, machineKeyStr,
pak.Namespace.Name,
RegisterMethodAuthKey, RegisterMethodAuthKey,
&registerRequest.Expiry, &registerRequest.Expiry,
pak, pak,
@ -592,7 +588,8 @@ func (h *Headscale) handleAuthKey(
Caller(). Caller().
Err(err). Err(err).
Msg("could not register machine") Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).Inc() machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not register machine", "could not register machine",
@ -600,10 +597,8 @@ func (h *Headscale) handleAuthKey(
return return
} }
}
pak.Used = true h.UsePreAuthKey(pak)
h.db.Save(&pak)
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
@ -612,21 +607,21 @@ func (h *Headscale) handleAuthKey(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

9
app.go
View file

@ -154,6 +154,8 @@ type Headscale struct {
requestedExpiryCache *cache.Cache requestedExpiryCache *cache.Cache
registrationCache *cache.Cache
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
} }
@ -207,6 +209,12 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
requestedExpiryCacheCleanupInterval, requestedExpiryCacheCleanupInterval,
) )
registrationCache := cache.New(
// TODO(kradalby): Add unified cache expiry config options
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
)
app := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
@ -214,6 +222,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
privateKey: privKey, privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
requestedExpiryCache: requestedExpiryCache, requestedExpiryCache: requestedExpiryCache,
registrationCache: registrationCache,
} }
err = app.initDB() err = app.initDB()

View file

@ -30,6 +30,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machineAfterRegistering, err := app.RegisterMachine( machineAfterRegistering, err := app.RegisterMachine(
"testmachine",
machine.MachineKey, machine.MachineKey,
namespace.Name, namespace.Name,
RegisterMethodCLI, RegisterMethodCLI,

View file

@ -174,12 +174,11 @@ func (api headscaleV1APIServer) RegisterMachine(
} }
} }
machine, err := api.h.RegisterMachine( machine, err := api.h.RegisterMachineFromAuthCallback(
request.GetKey(), request.GetKey(),
request.GetNamespace(), request.GetNamespace(),
RegisterMethodCLI, RegisterMethodCLI,
&requestedTime, &requestedTime,
nil, nil, nil,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -24,6 +24,10 @@ const (
errMachineAlreadyRegistered = Error("machine already registered") errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineRouteIsNotAvailable = Error("route is not available on machine")
errMachineAddressesInvalid = Error("failed to parse machine addresses") errMachineAddressesInvalid = Error("failed to parse machine addresses")
errMachineNotFoundRegistrationCache = Error(
"machine not found in registration cache",
)
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long") errHostnameTooLong = Error("Hostname too long")
) )
@ -686,14 +690,44 @@ func (machine *Machine) toProto() *v1.Machine {
return machineProto return machineProto
} }
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. func (h *Headscale) RegisterMachineFromAuthCallback(
func (h *Headscale) RegisterMachine(
machineKeyStr string, machineKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
expiry *time.Time,
) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
machine, err := h.RegisterMachine(
registrationMachine.Name,
machineKeyStr,
namespaceName,
registrationMethod,
expiry,
nil,
&registrationMachine.NodeKey,
registrationMachine.LastSeen,
)
return machine, err
} else {
return nil, errCouldNotConvertMachineInterface
}
}
return nil, errMachineNotFoundRegistrationCache
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(
machineName string,
machineKeyStr string,
namespaceName string,
registrationMethod string,
expiry *time.Time,
// Optionals // Optionals
expiry *time.Time,
authKey *PreAuthKey, authKey *PreAuthKey,
nodePublicKey *string, nodePublicKey *string,
lastSeen *time.Time, lastSeen *time.Time,
@ -768,6 +802,7 @@ func (h *Headscale) RegisterMachine(
machine.LastSeen = lastSeen machine.LastSeen = lastSeen
} }
machine.Name = machineName
machine.NamespaceID = namespace.ID machine.NamespaceID = namespace.ID
// TODO(kradalby): This field is uneccessary metadata, // TODO(kradalby): This field is uneccessary metadata,
@ -780,6 +815,7 @@ func (h *Headscale) RegisterMachine(
// Let us simplify the model, a machine is _only_ saved if // Let us simplify the model, a machine is _only_ saved if
// it is registered. // it is registered.
machine.Registered = true machine.Registered = true
h.db.Save(&machine) h.db.Save(&machine)
log.Trace(). log.Trace().

View file

@ -279,8 +279,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
now := time.Now().UTC()
namespaceName, err := NormalizeNamespaceName( namespaceName, err := NormalizeNamespaceName(
claims.Email, claims.Email,
h.cfg.OIDC.StripEmaildomain, h.cfg.OIDC.StripEmaildomain,
@ -328,14 +326,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
_, err = h.RegisterMachine( _, err = h.RegisterMachineFromAuthCallback(
machineKeyStr, machineKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
&requestedTime, &requestedTime,
nil,
nil,
&now,
) )
if err != nil { if err != nil {
log.Error(). log.Error().

View file

@ -113,6 +113,12 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
return nil return nil
} }
// UsePreAuthKey marks a PreAuthKey as used.
func (h *Headscale) UsePreAuthKey(k *PreAuthKey) {
k.Used = true
h.db.Save(k)
}
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used. // If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {