diff --git a/api.go b/api.go index fc27e46..ff0de0c 100644 --- a/api.go +++ b/api.go @@ -30,6 +30,44 @@ const ( ) ) +func (h *Headscale) HealthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + respond := func(err error) { + writer.Header().Set("Content-Type", "application/health+json; charset=utf-8") + + res := struct { + Status string `json:"status"` + }{ + Status: "pass", + } + + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + log.Error().Caller().Err(err).Msg("health check failed") + res.Status = "fail" + } + + buf, err := json.Marshal(res) + if err != nil { + log.Error().Caller().Err(err).Msg("marshal failed") + } + _, err = writer.Write(buf) + if err != nil { + log.Error().Caller().Err(err).Msg("write failed") + } + } + + if err := h.pingDB(); err != nil { + respond(err) + + return + } + + respond(nil) +} + // KeyHandler provides the Headscale pub key // Listens in /key. func (h *Headscale) KeyHandler( diff --git a/app.go b/app.go index e4e6910..f988048 100644 --- a/app.go +++ b/app.go @@ -423,19 +423,7 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine { func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { router := mux.NewRouter() - router.HandleFunc( - "/health", - func(writer http.ResponseWriter, req *http.Request) { - writer.WriteHeader(http.StatusOK) - _, err := writer.Write([]byte("{\"healthy\": \"ok\"}")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - }).Methods(http.MethodGet) - + router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) diff --git a/db.go b/db.go index e412468..5df9c23 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package headscale import ( + "context" "database/sql/driver" "encoding/json" "errors" @@ -220,6 +221,17 @@ func (h *Headscale) setValue(key string, value string) error { return nil } +func (h *Headscale) pingDB() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + db, err := h.db.DB() + if err != nil { + return err + } + + return db.PingContext(ctx) +} + // This is a "wrapper" type around tailscales // Hostinfo to allow us to add database "serialization" // methods. This allows us to use a typed values throughout