diff --git a/poll.go b/poll.go index 15945a9..3bad0b8 100644 --- a/poll.go +++ b/poll.go @@ -175,32 +175,13 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("machine", machine.Name). Msg("Loading or creating update channel") - // TODO: could probably remove all that duplication once generics land. - closeChanWithLog := func(channel interface{}, name string) { - log.Trace(). - Str("handler", "PollNetMap"). - Str("machine", machine.Name). - Str("channel", "Done"). - Msg(fmt.Sprintf("Closing %s channel", name)) - - switch c := channel.(type) { - case (chan struct{}): - close(c) - - case (chan []byte): - close(c) - } - } - const chanSize = 8 updateChan := make(chan struct{}, chanSize) - defer closeChanWithLog(updateChan, "updateChan") pollDataChan := make(chan []byte, chanSize) - defer closeChanWithLog(pollDataChan, "pollDataChan") + defer closeChanWithLog(pollDataChan, machine.Name, "pollDataChan") keepAliveChan := make(chan []byte) - defer closeChanWithLog(keepAliveChan, "keepAliveChan") if req.OmitPeers && !req.Stream { log.Info(). @@ -273,7 +254,27 @@ func (h *Headscale) PollNetMapStream( updateChan chan struct{}, ) { { - ctx, cancel := context.WithCancel(ctx.Request.Context()) + machine, err := h.GetMachineByMachineKey(machineKey) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Warn(). + Str("handler", "PollNetMap"). + Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) + ctx.String(http.StatusUnauthorized, "") + + return + } + log.Error(). + Str("handler", "PollNetMap"). + Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) + ctx.String(http.StatusInternalServerError, "") + + return + } + + ctx := context.WithValue(ctx.Request.Context(), "machineName", machine.Name) + + ctx, cancel := context.WithCancel(ctx) defer cancel() go h.scheduledPollWorker( @@ -564,8 +565,8 @@ func (h *Headscale) PollNetMapStream( func (h *Headscale) scheduledPollWorker( ctx context.Context, - updateChan chan<- struct{}, - keepAliveChan chan<- []byte, + updateChan chan struct{}, + keepAliveChan chan []byte, machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, machine *Machine, @@ -573,6 +574,17 @@ func (h *Headscale) scheduledPollWorker( keepAliveTicker := time.NewTicker(keepAliveInterval) updateCheckerTicker := time.NewTicker(updateCheckInterval) + defer closeChanWithLog( + updateChan, + fmt.Sprint(ctx.Value("machineName")), + "updateChan", + ) + defer closeChanWithLog( + keepAliveChan, + fmt.Sprint(ctx.Value("machineName")), + "updateChan", + ) + for { select { case <-ctx.Done(): @@ -606,3 +618,13 @@ func (h *Headscale) scheduledPollWorker( } } } + +func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("machine", machine). + Str("channel", "Done"). + Msg(fmt.Sprintf("Closing %s channel", name)) + + close(channel) +}