diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 17a9dcb..20f49af 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,7 +26,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: golangci/golangci-lint-action@v2 with: - version: v1.46.1 + version: v1.49.0 # Only block PRs on new problems. # If this is not enabled, we will end up having PRs diff --git a/acls_test.go b/acls_test.go index db04ee3..fc0f84e 100644 --- a/acls_test.go +++ b/acls_test.go @@ -825,7 +825,6 @@ func Test_listMachinesInNamespace(t *testing.T) { } } -// nolint func Test_expandAlias(t *testing.T) { type args struct { machines []Machine diff --git a/api.go b/api.go index 18ac72f..f5de503 100644 --- a/api.go +++ b/api.go @@ -52,7 +52,7 @@ func (h *Headscale) HealthHandler( } } - if err := h.pingDB(); err != nil { + if err := h.pingDB(req.Context()); err != nil { respond(err) return diff --git a/app.go b/app.go index 6e37fcd..59101be 100644 --- a/app.go +++ b/app.go @@ -18,7 +18,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/patrickmn/go-cache" @@ -601,7 +601,7 @@ func (h *Headscale) Serve() error { grpcOptions := []grpc.ServerOption{ grpc.UnaryInterceptor( - grpc_middleware.ChainUnaryServer( + grpcMiddleware.ChainUnaryServer( h.grpcAuthenticationInterceptor, zerolog.NewUnaryServerInterceptor(), ), @@ -820,10 +820,19 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { // Configuration via autocert with HTTP-01. This requires listening on // port 80 for the certificate validation in addition to the headscale // service, which can be configured to run on any other port. + + server := &http.Server{ + Addr: h.cfg.TLS.LetsEncrypt.Listen, + Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)), + ReadTimeout: HTTPReadTimeout, + } + + err := server.ListenAndServe() + go func() { log.Fatal(). Caller(). - Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). + Err(err). Msg("failed to set up a HTTP server") }() @@ -860,19 +869,17 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } -func (h *Headscale) setLastStateChangeToNow(namespaces ...string) { +func (h *Headscale) setLastStateChangeToNow() { var err error now := time.Now().UTC() - if len(namespaces) == 0 { - namespaces, err = h.ListNamespacesStr() - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("failed to fetch all namespaces, failing to update last changed state.") - } + namespaces, err := h.ListNamespacesStr() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("failed to fetch all namespaces, failing to update last changed state.") } for _, namespace := range namespaces { diff --git a/config.go b/config.go index 678b575..14350b7 100644 --- a/config.go +++ b/config.go @@ -5,12 +5,11 @@ import ( "errors" "fmt" "io/fs" + "net/netip" "net/url" "strings" "time" - "net/netip" - "github.com/coreos/go-oidc/v3/oidc" "github.com/rs/zerolog" "github.com/rs/zerolog/log" diff --git a/db.go b/db.go index 17df384..a1a4ef3 100644 --- a/db.go +++ b/db.go @@ -221,8 +221,8 @@ func (h *Headscale) setValue(key string, value string) error { return nil } -func (h *Headscale) pingDB() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (h *Headscale) pingDB(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() db, err := h.db.DB() if err != nil { diff --git a/derp.go b/derp.go index 6a153d2..c3d100b 100644 --- a/derp.go +++ b/derp.go @@ -34,7 +34,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil) if err != nil { return nil, err } diff --git a/derp_server.go b/derp_server.go index dbdbc7a..6fe897b 100644 --- a/derp_server.go +++ b/derp_server.go @@ -154,7 +154,7 @@ func (h *Headscale) DERPHandler( if !fastStart { pubKey := h.privateKey.Public() - pubKeyStr := pubKey.UntypedHexString() // nolint + pubKeyStr := pubKey.UntypedHexString() //nolint fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+ "Upgrade: DERP\r\n"+ "Connection: Upgrade\r\n"+ @@ -174,7 +174,7 @@ func (h *Headscale) DERPProbeHandler( req *http.Request, ) { switch req.Method { - case "HEAD", "GET": + case http.MethodHead, http.MethodGet: writer.Header().Set("Access-Control-Allow-Origin", "*") writer.WriteHeader(http.StatusOK) default: @@ -202,7 +202,7 @@ func (h *Headscale) DERPBootstrapDNSHandler( ) { dnsEntries := make(map[string][]net.IP) - resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) + resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() var resolver net.Resolver for _, region := range h.DERPMap.Regions { diff --git a/machine_test.go b/machine_test.go index cadd0df..e5ef19b 100644 --- a/machine_test.go +++ b/machine_test.go @@ -540,7 +540,6 @@ func Test_getTags(t *testing.T) { } } -// nolint func Test_getFilteredByACLPeers(t *testing.T) { type args struct { machines []Machine diff --git a/noise.go b/noise.go index c8e6674..45bff7b 100644 --- a/noise.go +++ b/noise.go @@ -31,7 +31,9 @@ func (h *Headscale) NoiseUpgradeHandler( return } - server := http.Server{} + server := http.Server{ + ReadTimeout: HTTPReadTimeout, + } server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{}) err = server.Serve(netutil.NewOneConnListener(noiseConn, nil)) if err != nil { diff --git a/oidc.go b/oidc.go index 60d531e..f0af600 100644 --- a/oidc.go +++ b/oidc.go @@ -148,12 +148,12 @@ func (h *Headscale) OIDCCallback( return } - rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state) + rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state) if err != nil { return } - idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) + idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken) if err != nil { return } @@ -240,10 +240,11 @@ func validateOIDCCallbackParams( } func (h *Headscale) getIDTokenForOIDCCallback( + ctx context.Context, writer http.ResponseWriter, code, state string, ) (string, error) { - oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) + oauth2Token, err := h.oauth2Config.Exchange(ctx, code) if err != nil { log.Error(). Err(err). @@ -287,11 +288,12 @@ func (h *Headscale) getIDTokenForOIDCCallback( } func (h *Headscale) verifyIDTokenForOIDCCallback( + ctx context.Context, writer http.ResponseWriter, rawIDToken string, ) (*oidc.IDToken, error) { verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) - idToken, err := verifier.Verify(context.Background(), rawIDToken) + idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { log.Error(). Err(err). diff --git a/protocol_common.go b/protocol_common.go index 154c14c..a1ff06e 100644 --- a/protocol_common.go +++ b/protocol_common.go @@ -105,7 +105,7 @@ func (h *Headscale) handleRegisterCommon( if errors.Is(err, gorm.ErrRecordNotFound) { // If the machine has AuthKey set, handle registration via PreAuthKeys if registerRequest.Auth.AuthKey != "" { - h.handleAuthKeyCommon(writer, req, registerRequest, machineKey) + h.handleAuthKeyCommon(writer, registerRequest, machineKey) return } @@ -134,7 +134,7 @@ func (h *Headscale) handleRegisterCommon( case <-req.Context().Done(): return case <-ticker.C: - h.handleNewMachineCommon(writer, req, registerRequest, machineKey) + h.handleNewMachineCommon(writer, registerRequest, machineKey) return } @@ -190,7 +190,7 @@ func (h *Headscale) handleRegisterCommon( registerCacheExpiration, ) - h.handleNewMachineCommon(writer, req, registerRequest, machineKey) + h.handleNewMachineCommon(writer, registerRequest, machineKey) return } @@ -207,7 +207,7 @@ func (h *Headscale) handleRegisterCommon( // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !registerRequest.Expiry.IsZero() && registerRequest.Expiry.UTC().Before(now) { - h.handleMachineLogOutCommon(writer, req, *machine, machineKey) + h.handleMachineLogOutCommon(writer, *machine, machineKey) return } @@ -215,7 +215,7 @@ func (h *Headscale) handleRegisterCommon( // If machine is not expired, and is register, we have a already accepted this machine, // let it proceed with a valid registration if !machine.isExpired() { - h.handleMachineValidRegistrationCommon(writer, req, *machine, machineKey) + h.handleMachineValidRegistrationCommon(writer, *machine, machineKey) return } @@ -226,7 +226,6 @@ func (h *Headscale) handleRegisterCommon( !machine.isExpired() { h.handleMachineRefreshKeyCommon( writer, - req, registerRequest, *machine, machineKey, @@ -236,7 +235,7 @@ func (h *Headscale) handleRegisterCommon( } // The machine has expired - h.handleMachineExpiredCommon(writer, req, registerRequest, *machine, machineKey) + h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey) machine.Expiry = &time.Time{} h.registrationCache.Set( @@ -256,7 +255,6 @@ func (h *Headscale) handleRegisterCommon( // TODO: check if any locks are needed around IP allocation. func (h *Headscale) handleAuthKeyCommon( writer http.ResponseWriter, - req *http.Request, registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { @@ -455,7 +453,6 @@ func (h *Headscale) handleAuthKeyCommon( // for authorizing the machine. This url is then showed to the user by the local Tailscale client. func (h *Headscale) handleNewMachineCommon( writer http.ResponseWriter, - req *http.Request, registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { @@ -511,7 +508,6 @@ func (h *Headscale) handleNewMachineCommon( func (h *Headscale) handleMachineLogOutCommon( writer http.ResponseWriter, - req *http.Request, machine Machine, machineKey key.MachinePublic, ) { @@ -570,7 +566,6 @@ func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleMachineValidRegistrationCommon( writer http.ResponseWriter, - req *http.Request, machine Machine, machineKey key.MachinePublic, ) { @@ -624,7 +619,6 @@ func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleMachineRefreshKeyCommon( writer http.ResponseWriter, - req *http.Request, registerRequest tailcfg.RegisterRequest, machine Machine, machineKey key.MachinePublic, @@ -684,7 +678,6 @@ func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleMachineExpiredCommon( writer http.ResponseWriter, - req *http.Request, registerRequest tailcfg.RegisterRequest, machine Machine, machineKey key.MachinePublic, @@ -699,7 +692,7 @@ func (h *Headscale) handleMachineExpiredCommon( Msg("Machine registration has expired. Sending a authurl to register") if registerRequest.Auth.AuthKey != "" { - h.handleAuthKeyCommon(writer, req, registerRequest, machineKey) + h.handleAuthKeyCommon(writer, registerRequest, machineKey) return } diff --git a/protocol_common_poll.go b/protocol_common_poll.go index 65dcb55..6dedfd0 100644 --- a/protocol_common_poll.go +++ b/protocol_common_poll.go @@ -22,7 +22,7 @@ const machineNameContextKey = contextKey("machineName") // managed the poll loop. func (h *Headscale) handlePollCommon( writer http.ResponseWriter, - req *http.Request, + ctx context.Context, machine *Machine, mapRequest tailcfg.MapRequest, isNoise bool, @@ -201,7 +201,7 @@ func (h *Headscale) handlePollCommon( h.pollNetMapStream( writer, - req, + ctx, machine, mapRequest, pollDataChan, @@ -221,7 +221,7 @@ func (h *Headscale) handlePollCommon( // ensuring we communicate updates and data to the connected clients. func (h *Headscale) pollNetMapStream( writer http.ResponseWriter, - req *http.Request, + ctxReq context.Context, machine *Machine, mapRequest tailcfg.MapRequest, pollDataChan chan []byte, @@ -232,7 +232,7 @@ func (h *Headscale) pollNetMapStream( h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() - ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) + ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname) ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/protocol_common_utils.go b/protocol_common_utils.go index 3dc435f..a189418 100644 --- a/protocol_common_utils.go +++ b/protocol_common_utils.go @@ -75,6 +75,8 @@ func (h *Headscale) marshalResponse( Caller(). Err(err). Msg("Cannot marshal response") + + return nil, err } if machineKey.IsZero() { // if Noise diff --git a/protocol_legacy_poll.go b/protocol_legacy_poll.go index f7ef654..f27ee4e 100644 --- a/protocol_legacy_poll.go +++ b/protocol_legacy_poll.go @@ -90,5 +90,5 @@ func (h *Headscale) PollNetMapHandler( Str("machine", machine.Hostname). Msg("A machine is entering polling via the legacy protocol") - h.handlePollCommon(writer, req, machine, mapRequest, false) + h.handlePollCommon(writer, req.Context(), machine, mapRequest, false) } diff --git a/protocol_noise_poll.go b/protocol_noise_poll.go index 8498dcf..b15183c 100644 --- a/protocol_noise_poll.go +++ b/protocol_noise_poll.go @@ -63,5 +63,5 @@ func (h *Headscale) NoisePollNetMapHandler( Str("machine", machine.Hostname). Msg("A machine is entering polling via the Noise protocol") - h.handlePollCommon(writer, req, machine, mapRequest, true) + h.handlePollCommon(writer, req.Context(), machine, mapRequest, true) }