diff --git a/poll.go b/protocol_common_poll.go similarity index 84% rename from poll.go rename to protocol_common_poll.go index a51c936..988d225 100644 --- a/poll.go +++ b/protocol_common_poll.go @@ -2,17 +2,12 @@ package headscale import ( "context" - "errors" "fmt" - "io" "net/http" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog/log" - "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) const ( @@ -23,83 +18,13 @@ type contextKey string const machineNameContextKey = contextKey("machineName") -// PollNetMapHandler takes care of /machine/:id/map -// -// This is the busiest endpoint, as it keeps the HTTP long poll that updates -// the clients when something in the network changes. -// -// The clients POST stuff like HostInfo and their Endpoints here, but -// only after their first request (marked with the ReadOnly field). -// -// At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (h *Headscale) PollNetMapHandler( +func (h *Headscale) handlePollCommon( writer http.ResponseWriter, req *http.Request, + machine *Machine, + mapRequest tailcfg.MapRequest, + isNoise bool, ) { - vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - if !ok || machineKeyStr == "" { - log.Error(). - Str("handler", "PollNetMap"). - Msg("No machine key in request") - http.Error(writer, "No machine key in request", http.StatusBadRequest) - - return - } - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", machineKeyStr). - Msg("PollNetMapHandler called") - body, _ := io.ReadAll(req.Body) - - var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Err(err). - Msg("Cannot parse client key") - - http.Error(writer, "Cannot parse client key", http.StatusBadRequest) - - return - } - mapRequest := tailcfg.MapRequest{} - err = decode(body, &mapRequest, &machineKey, h.privateKey) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Err(err). - Msg("Cannot decode message") - http.Error(writer, "Cannot decode message", http.StatusBadRequest) - - return - } - - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) - - http.Error(writer, "", http.StatusUnauthorized) - - return - } - log.Error(). - Str("handler", "PollNetMap"). - Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) - http.Error(writer, "", http.StatusInternalServerError) - - return - } - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", machineKeyStr). - Str("machine", machine.Hostname). - Msg("Found machine in database") - machine.Hostname = mapRequest.Hostinfo.Hostname machine.HostInfo = HostInfo(*mapRequest.Hostinfo) machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) @@ -107,7 +32,7 @@ func (h *Headscale) PollNetMapHandler( // update ACLRules with peer informations (to update server tags if necessary) if h.aclPolicy != nil { - err = h.UpdateACLRules() + err := h.UpdateACLRules() if err != nil { log.Error(). Caller(). @@ -133,7 +58,7 @@ func (h *Headscale) PollNetMapHandler( if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", machineKeyStr). + Str("node_key", machine.NodeKey). Str("machine", machine.Hostname). Err(err). Msg("Failed to persist/update machine in the database") @@ -143,11 +68,11 @@ func (h *Headscale) PollNetMapHandler( } } - data, err := h.getLegacyMapResponseData(machineKey, mapRequest, machine) + mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", machineKeyStr). + Str("node_key", machine.NodeKey). Str("machine", machine.Hostname). Err(err). Msg("Failed to get Map response") @@ -163,7 +88,6 @@ func (h *Headscale) PollNetMapHandler( // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 log.Debug(). Str("handler", "PollNetMap"). - Str("id", machineKeyStr). Str("machine", machine.Hostname). Bool("readOnly", mapRequest.ReadOnly). Bool("omitPeers", mapRequest.OmitPeers). @@ -178,7 +102,7 @@ func (h *Headscale) PollNetMapHandler( writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(data) + _, err := writer.Write(mapResp) if err != nil { log.Error(). Caller(). @@ -186,6 +110,10 @@ func (h *Headscale) PollNetMapHandler( Msg("Failed to write response") } + if f, ok := writer.(http.Flusher); ok { + f.Flush() + } + return } @@ -198,8 +126,7 @@ func (h *Headscale) PollNetMapHandler( // Only create update channel if it has not been created log.Trace(). - Str("handler", "PollNetMap"). - Str("id", machineKeyStr). + Caller(). Str("machine", machine.Hostname). Msg("Loading or creating update channel") @@ -218,7 +145,7 @@ func (h *Headscale) PollNetMapHandler( Msg("Client sent endpoint update and is ok with a response without peer list") writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(data) + _, err := writer.Write(mapResp) if err != nil { log.Error(). Caller(). @@ -250,7 +177,7 @@ func (h *Headscale) PollNetMapHandler( Str("handler", "PollNetMap"). Str("machine", machine.Hostname). Msg("Sending initial map") - pollDataChan <- data + pollDataChan <- mapResp log.Info(). Str("handler", "PollNetMap"). @@ -260,35 +187,34 @@ func (h *Headscale) PollNetMapHandler( Inc() updateChan <- struct{}{} - h.PollNetMapStream( + h.pollNetMapStream( writer, req, machine, mapRequest, - machineKey, pollDataChan, keepAliveChan, updateChan, + isNoise, ) + log.Trace(). Str("handler", "PollNetMap"). - Str("id", machineKeyStr). Str("machine", machine.Hostname). Msg("Finished stream, closing PollNetMap session") } -// PollNetMapStream takes care of /machine/:id/map -// stream logic, ensuring we communicate updates and data -// to the connected clients. -func (h *Headscale) PollNetMapStream( +// pollNetMapStream stream logic for /machine/map, +// ensuring we communicate updates and data to the connected clients. +func (h *Headscale) pollNetMapStream( writer http.ResponseWriter, req *http.Request, machine *Machine, mapRequest tailcfg.MapRequest, - machineKey key.MachinePublic, pollDataChan chan []byte, keepAliveChan chan []byte, updateChan chan struct{}, + isNoise bool, ) { h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -302,9 +228,9 @@ func (h *Headscale) PollNetMapStream( ctx, updateChan, keepAliveChan, - machineKey, mapRequest, machine, + isNoise, ) log.Trace(). @@ -491,7 +417,7 @@ func (h *Headscale) PollNetMapStream( Time("last_successful_update", lastUpdate). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). Msgf("There has been updates since the last successful update to %s", machine.Hostname) - data, err := h.getLegacyMapResponseData(machineKey, mapRequest, machine) + data, err := h.getMapResponseData(mapRequest, machine, false) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -637,9 +563,9 @@ func (h *Headscale) scheduledPollWorker( ctx context.Context, updateChan chan struct{}, keepAliveChan chan []byte, - machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, machine *Machine, + isNoise bool, ) { keepAliveTicker := time.NewTicker(keepAliveInterval) updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval) @@ -661,7 +587,7 @@ func (h *Headscale) scheduledPollWorker( return case <-keepAliveTicker.C: - data, err := h.getMapKeepAliveResponse(machineKey, mapRequest) + data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise) if err != nil { log.Error(). Str("func", "keepAlive"). diff --git a/protocol_legacy_poll.go b/protocol_legacy_poll.go new file mode 100644 index 0000000..a42f399 --- /dev/null +++ b/protocol_legacy_poll.go @@ -0,0 +1,94 @@ +package headscale + +import ( + "errors" + "io" + "net/http" + + "github.com/gorilla/mux" + "github.com/rs/zerolog/log" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// PollNetMapHandler takes care of /machine/:id/map +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. +// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. +func (h *Headscale) PollNetMapHandler( + writer http.ResponseWriter, + req *http.Request, +) { + vars := mux.Vars(req) + machineKeyStr, ok := vars["mkey"] + if !ok || machineKeyStr == "" { + log.Error(). + Str("handler", "PollNetMap"). + Msg("No machine key in request") + http.Error(writer, "No machine key in request", http.StatusBadRequest) + + return + } + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", machineKeyStr). + Msg("PollNetMapHandler called") + body, _ := io.ReadAll(req.Body) + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot parse client key") + + http.Error(writer, "Cannot parse client key", http.StatusBadRequest) + + return + } + mapRequest := tailcfg.MapRequest{} + err = decode(body, &mapRequest, &machineKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot decode message") + http.Error(writer, "Cannot decode message", http.StatusBadRequest) + + return + } + + machine, err := h.GetMachineByMachineKey(machineKey) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Warn(). + Str("handler", "PollNetMap"). + Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) + + http.Error(writer, "", http.StatusUnauthorized) + + return + } + log.Error(). + Str("handler", "PollNetMap"). + Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) + http.Error(writer, "", http.StatusInternalServerError) + + return + } + + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", machineKeyStr). + Str("machine", machine.Hostname). + Msg("Found machine in database") + + h.handlePollCommon(writer, req, machine, mapRequest, false) +}