diff --git a/poll.go b/poll.go index 239f260..1d21508 100644 --- a/poll.go +++ b/poll.go @@ -2,13 +2,14 @@ package headscale import ( "context" + "encoding/json" "errors" "fmt" "io" "net/http" "time" - "github.com/gin-gonic/gin" + "github.com/gorilla/mux" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -33,13 +34,25 @@ const machineNameContextKey = contextKey("machineName") // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { +func (h *Headscale) PollNetMapHandler( + w http.ResponseWriter, + r *http.Request, +) { + vars := mux.Vars(r) + machineKeyStr, ok := vars["mkey"] + if !ok || machineKeyStr == "" { + log.Error(). + Str("handler", "PollNetMap"). + Msg("No machine key in request") + http.Error(w, "No machine key in request", http.StatusBadRequest) + + return + } log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Msg("PollNetMapHandler called") - body, _ := io.ReadAll(ctx.Request.Body) - machineKeyStr := ctx.Param("id") + body, _ := io.ReadAll(r.Body) var machineKey key.MachinePublic err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) @@ -48,7 +61,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Err(err). Msg("Cannot parse client key") - ctx.String(http.StatusBadRequest, "") + + http.Error(w, "Cannot parse client key", http.StatusBadRequest) return } @@ -59,7 +73,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Err(err). Msg("Cannot decode message") - ctx.String(http.StatusBadRequest, "") + http.Error(w, "Cannot decode message", http.StatusBadRequest) return } @@ -70,20 +84,21 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { log.Warn(). Str("handler", "PollNetMap"). Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) - ctx.String(http.StatusUnauthorized, "") + + http.Error(w, "", http.StatusUnauthorized) return } log.Error(). Str("handler", "PollNetMap"). Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) - ctx.String(http.StatusInternalServerError, "") + http.Error(w, "", http.StatusInternalServerError) return } log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Found machine in database") @@ -120,11 +135,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Err(err). Msg("Failed to persist/update machine in the database") - ctx.String(http.StatusInternalServerError, ":(") + http.Error(w, "", http.StatusInternalServerError) return } @@ -134,11 +149,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Err(err). Msg("Failed to get Map response") - ctx.String(http.StatusInternalServerError, ":(") + http.Error(w, "", http.StatusInternalServerError) return } @@ -150,7 +165,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 log.Debug(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Bool("readOnly", req.ReadOnly). Bool("omitPeers", req.OmitPeers). @@ -162,7 +177,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Client is starting up. Probably interested in a DERP map") - ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(data) return } @@ -177,7 +195,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // Only create update channel if it has not been created log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Loading or creating update channel") @@ -194,8 +212,9 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Client sent endpoint update and is ok with a response without peer list") - ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) - + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(data) // It sounds like we should update the nodes when we have received a endpoint update // even tho the comments in the tailscale code dont explicitly say so. updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update"). @@ -208,7 +227,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Ignoring request, don't know how to handle it") - ctx.String(http.StatusBadRequest, "") + http.Error(w, "", http.StatusBadRequest) return } @@ -232,7 +251,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { updateChan <- struct{}{} h.PollNetMapStream( - ctx, + w, + r, machine, req, machineKey, @@ -242,7 +262,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { ) log.Trace(). Str("handler", "PollNetMap"). - Str("id", ctx.Param("id")). + Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Finished stream, closing PollNetMap session") } @@ -251,7 +271,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // stream logic, ensuring we communicate updates and data // to the connected clients. func (h *Headscale) PollNetMapStream( - ctx *gin.Context, + w http.ResponseWriter, + r *http.Request, machine *Machine, mapRequest tailcfg.MapRequest, machineKey key.MachinePublic, @@ -259,41 +280,21 @@ func (h *Headscale) PollNetMapStream( keepAliveChan chan []byte, updateChan chan struct{}, ) { - { - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) - ctx.String(http.StatusUnauthorized, "") + ctx := context.WithValue(context.Background(), machineNameContextKey, machine.Hostname) - return - } - log.Error(). - Str("handler", "PollNetMap"). - Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) - ctx.String(http.StatusInternalServerError, "") + ctx, cancel := context.WithCancel(ctx) + defer cancel() - return - } + go h.scheduledPollWorker( + ctx, + updateChan, + keepAliveChan, + machineKey, + mapRequest, + machine, + ) - ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go h.scheduledPollWorker( - ctx, - updateChan, - keepAliveChan, - machineKey, - mapRequest, - machine, - ) - } - - ctx.Stream(func(writer io.Writer) bool { + for { log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -312,7 +313,7 @@ func (h *Headscale) PollNetMapStream( Str("channel", "pollData"). Int("bytes", len(data)). Msg("Sending data received via pollData channel") - _, err := writer.Write(data) + _, err := w.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -321,7 +322,7 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Cannot write data") - return false + break } log.Trace(). Str("handler", "PollNetMapStream"). @@ -343,7 +344,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + break } now := time.Now().UTC() machine.LastSeen = &now @@ -369,7 +370,7 @@ func (h *Headscale) PollNetMapStream( Msg("Machine entry in database updated successfully after sending pollData") } - return true + break case data := <-keepAliveChan: log.Trace(). @@ -378,7 +379,7 @@ func (h *Headscale) PollNetMapStream( Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Sending keep alive message") - _, err := writer.Write(data) + _, err := w.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -387,7 +388,7 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Cannot write keep alive message") - return false + break } log.Trace(). Str("handler", "PollNetMapStream"). @@ -409,7 +410,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + break } now := time.Now().UTC() machine.LastSeen = &now @@ -430,7 +431,7 @@ func (h *Headscale) PollNetMapStream( Msg("Machine updated successfully after sending keep alive") } - return true + break case <-updateChan: log.Trace(). @@ -460,7 +461,7 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Could not get the map update") } - _, err = writer.Write(data) + _, err = w.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -471,7 +472,7 @@ func (h *Headscale) PollNetMapStream( updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed"). Inc() - return false + return } log.Trace(). Str("handler", "PollNetMapStream"). @@ -499,7 +500,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + return } now := time.Now().UTC() @@ -529,9 +530,9 @@ func (h *Headscale) PollNetMapStream( Msgf("%s is up to date", machine.Hostname) } - return true + return - case <-ctx.Request.Context().Done(): + case <-ctx.Done(): log.Info(). Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). @@ -550,7 +551,7 @@ func (h *Headscale) PollNetMapStream( // client has been removed from database // since the stream opened, terminate connection. - return false + break } now := time.Now().UTC() machine.LastSeen = &now @@ -564,9 +565,11 @@ func (h *Headscale) PollNetMapStream( Msg("Cannot update machine LastSeen") } - return false + break } - }) + } + + log.Info().Msgf("Closing poll loop to %s", machine.Hostname) } func (h *Headscale) scheduledPollWorker(