diff --git a/api.go b/api.go index c5002bb..50af552 100644 --- a/api.go +++ b/api.go @@ -375,13 +375,13 @@ func (h *Headscale) handleMachineExpired( Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "") return } - machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name). + machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } @@ -503,36 +503,46 @@ func (h *Headscale) handleAuthKey( return } - log.Debug(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Msg("Authentication key was valid, proceeding to acquire an IP address") - ip, err := h.getAvailableIP() - if err != nil { - log.Error(). + if machine.isRegistered() { + log.Trace(). + Caller(). + Str("machine", machine.Name). + Msg("machine already registered, reauthenticating") + + h.RefreshMachine(&machine, reqisterRequest.Expiry) + } else { + log.Debug(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Msg("Failed to find an available IP") - machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). - Inc() + Msg("Authentication key was valid, proceeding to acquire an IP address") + ip, err := h.getAvailableIP() + if err != nil { + log.Error(). + Str("func", "handleAuthKey"). + Str("machine", machine.Name). + Msg("Failed to find an available IP") + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + Inc() - return + return + } + log.Info(). + Str("func", "handleAuthKey"). + Str("machine", machine.Name). + Str("ip", ip.String()). + Msgf("Assigning %s to %s", ip, machine.Name) + + machine.Expiry = &reqisterRequest.Expiry + machine.AuthKeyID = uint(pak.ID) + machine.IPAddress = ip.String() + machine.NamespaceID = pak.NamespaceID + machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey). + HexString() + // we update it just in case + machine.Registered = true + machine.RegisterMethod = RegisterMethodAuthKey + h.db.Save(&machine) } - log.Info(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Str("ip", ip.String()). - Msgf("Assigning %s to %s", ip, machine.Name) - - machine.AuthKeyID = uint(pak.ID) - machine.IPAddress = ip.String() - machine.NamespaceID = pak.NamespaceID - machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey). - HexString() - // we update it just in case - machine.Registered = true - machine.RegisterMethod = RegisterMethodAuthKey - h.db.Save(&machine) pak.Used = true h.db.Save(&pak) @@ -558,6 +568,6 @@ func (h *Headscale) handleAuthKey( log.Info(). Str("func", "handleAuthKey"). Str("machine", machine.Name). - Str("ip", ip.String()). + Str("ip", machine.IPAddress). Msg("Successfully authenticated via AuthKey") } diff --git a/machine.go b/machine.go index d6903bd..dd7124e 100644 --- a/machine.go +++ b/machine.go @@ -270,6 +270,15 @@ func (h *Headscale) ExpireMachine(machine *Machine) { h.db.Save(machine) } +// RefreshMachine takes a Machine struct and sets the expire field to now. +func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) { + now := time.Now() + + machine.LastSuccessfulUpdate = &now + machine.Expiry = &expiry + h.db.Save(machine) +} + // DeleteMachine softs deletes a Machine from the database. func (h *Headscale) DeleteMachine(machine *Machine) error { err := h.RemoveSharedMachineFromAllNamespaces(machine) @@ -644,6 +653,7 @@ func (h *Headscale) RegisterMachine( machine.NamespaceID = namespace.ID machine.Registered = true machine.RegisterMethod = RegisterMethodCLI + machine.Expiry = &requestedTime h.db.Save(&machine) log.Trace(). diff --git a/oidc.go b/oidc.go index cdd4fc2..2f7d6d6 100644 --- a/oidc.go +++ b/oidc.go @@ -81,6 +81,11 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { return } + log.Trace(). + Caller(). + Str("machine_key", machineKeyStr). + Msg("Received oidc register call") + randomBlob := make([]byte, randomByteSize) if _, err := rand.Read(randomBlob); err != nil { log.Error(). @@ -124,7 +129,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken) + log.Trace(). + Caller(). + Str("code", code). + Str("state", state). + Msg("Got oidc callback") rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { @@ -202,6 +211,29 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + if machine.isRegistered() { + log.Trace(). + Caller(). + Str("machine", machine.Name). + Msg("machine already registered, reauthenticating") + + h.RefreshMachine(machine, requestedTime) + + ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + + +

headscale

+

+ Reuthenticated as %s, you can now close this window. +

+ + + +`, claims.Email))) + + return + } + now := time.Now().UTC() if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { @@ -258,6 +290,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machine.Registered = true machine.RegisterMethod = RegisterMethodOIDC machine.LastSuccessfulUpdate = &now + machine.Expiry = &requestedTime h.db.Save(&machine) }