diff --git a/hscontrol/app.go b/hscontrol/app.go index e72aca2..b8eb6f6 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -70,7 +70,7 @@ var ( const ( AuthPrefix = "Bearer " - updateInterval = 5000 + updateInterval = 5 * time.Second privateKeyFileMode = 0o600 headscaleDirPerm = 0o700 @@ -219,64 +219,75 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { // deleteExpireEphemeralNodes deletes ephemeral node records that have not been // seen for longer than h.cfg.EphemeralNodeInactivityTimeout. -func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) { - ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) +func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) { + ticker := time.NewTicker(every) - for range ticker.C { - var removed []types.NodeID - var changed []types.NodeID - if err := h.db.Write(func(tx *gorm.DB) error { - removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + var removed []types.NodeID + var changed []types.NodeID + if err := h.db.Write(func(tx *gorm.DB) error { + removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) - return nil - }); err != nil { - log.Error().Err(err).Msg("database error while expiring ephemeral nodes") - continue - } + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring ephemeral nodes") + continue + } - if removed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: removed, - }) - } + if removed != nil { + ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: removed, + }) + } - if changed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changed, - }) + if changed != nil { + ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: changed, + }) + } } } } -// expireExpiredMachines expires nodes that have an explicit expiry set +// expireExpiredNodes expires nodes that have an explicit expiry set // after that expiry time has passed. -func (h *Headscale) expireExpiredMachines(intervalMs int64) { - interval := time.Duration(intervalMs) * time.Millisecond - ticker := time.NewTicker(interval) +func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { + ticker := time.NewTicker(every) lastCheck := time.Unix(0, 0) var update types.StateUpdate var changed bool - for range ticker.C { - if err := h.db.Write(func(tx *gorm.DB) error { - lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + if err := h.db.Write(func(tx *gorm.DB) error { + lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) - return nil - }); err != nil { - log.Error().Err(err).Msg("database error while expiring nodes") - continue - } + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring nodes") + continue + } - if changed { - log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes") + if changed { + log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes") - ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") - h.nodeNotifier.NotifyAll(ctx, update) + ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } } @@ -538,10 +549,13 @@ func (h *Headscale) Serve() error { return errEmptyInitialDERPMap } - // TODO(kradalby): These should have cancel channels and be cleaned - // up on shutdown. - go h.deleteExpireEphemeralNodes(updateInterval) - go h.expireExpiredMachines(updateInterval) + expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background()) + defer expireEphemeralCancel() + go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval) + + expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) + defer expireNodeCancel() + go h.expireExpiredNodes(expireNodeCtx, updateInterval) if zl.GlobalLevel() == zl.TraceLevel { zerolog.RespLog = true @@ -805,6 +819,9 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") + expireNodeCancel() + expireEphemeralCancel() + trace("closing map sessions") wg := sync.WaitGroup{} for _, mapSess := range h.mapSessions {