diff --git a/api.go b/api.go index 9ed065f..fc27e46 100644 --- a/api.go +++ b/api.go @@ -413,7 +413,17 @@ func (h *Headscale) handleMachineLogOut( Str("machine", machine.Hostname). Msg("Client requested logout") - h.ExpireMachine(&machine) + err := h.ExpireMachine(&machine) + if err != nil { + log.Error(). + Caller(). + Str("func", "handleMachineLogOut"). + Err(err). + Msg("Failed to expire machine") + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } resp.AuthURL = "" resp.MachineAuthorized = false @@ -716,7 +726,16 @@ func (h *Headscale) handleAuthKey( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - h.RefreshMachine(machine, registerRequest.Expiry) + err := h.RefreshMachine(machine, registerRequest.Expiry) + if err != nil { + log.Error(). + Caller(). + Str("machine", machine.Hostname). + Err(err). + Msg("Failed to refresh machine") + + return + } } else { now := time.Now().UTC() @@ -759,7 +778,18 @@ func (h *Headscale) handleAuthKey( } } - h.UsePreAuthKey(pak) + err = h.UsePreAuthKey(pak) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to use pre-auth key") + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). + Inc() + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() diff --git a/machine_test.go b/machine_test.go index 0287b0c..a06d0db 100644 --- a/machine_test.go +++ b/machine_test.go @@ -249,10 +249,12 @@ func (s *Suite) TestExpireMachine(c *check.C) { machineFromDB, err := app.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) + c.Assert(machineFromDB, check.NotNil) c.Assert(machineFromDB.isExpired(), check.Equals, false) - app.ExpireMachine(machineFromDB) + err = app.ExpireMachine(machineFromDB) + c.Assert(err, check.IsNil) c.Assert(machineFromDB.isExpired(), check.Equals, true) } diff --git a/oidc.go b/oidc.go index 09365a4..8b5f024 100644 --- a/oidc.go +++ b/oidc.go @@ -345,7 +345,16 @@ func (h *Headscale) OIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, time.Time{}) + err := h.RefreshMachine(machine, time.Time{}) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to refresh machine") + http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError) + + return + } var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ @@ -373,7 +382,7 @@ func (h *Headscale) OIDCCallback( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(content.Bytes()) + _, err = writer.Write(content.Bytes()) if err != nil { log.Error(). Caller().