diff --git a/app.go b/app.go index d7d5ea5..a0a568f 100644 --- a/app.go +++ b/app.go @@ -94,7 +94,8 @@ type Headscale struct { ipAllocationMutex sync.Mutex - shutdownChan chan struct{} + shutdownChan chan struct{} + pollNetMapStreamWG sync.WaitGroup } // Look up the TLS constant relative to user-supplied TLS client @@ -147,12 +148,13 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { ) app := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - aclRules: tailcfg.FilterAllowAll, // default allowall - registrationCache: registrationCache, + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + aclRules: tailcfg.FilterAllowAll, // default allowall + registrationCache: registrationCache, + pollNetMapStreamWG: sync.WaitGroup{}, } err = app.initDB() @@ -565,6 +567,8 @@ func (h *Headscale) Serve() error { // https://github.com/soheilhy/cmux/issues/68 // https://github.com/soheilhy/cmux/issues/91 + var grpcServer *grpc.Server + var grpcListener net.Listener if tlsConfig != nil || h.cfg.GRPCAllowInsecure { log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) @@ -585,12 +589,12 @@ func (h *Headscale) Serve() error { log.Warn().Msg("gRPC is running without security") } - grpcServer := grpc.NewServer(grpcOptions...) + grpcServer = grpc.NewServer(grpcOptions...) v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) reflection.Register(grpcServer) - grpcListener, err := net.Listen("tcp", h.cfg.GRPCAddr) + grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) if err != nil { return fmt.Errorf("failed to bind to TCP address: %w", err) } @@ -666,7 +670,7 @@ func (h *Headscale) Serve() error { syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP) - go func(c chan os.Signal) { + sigFunc := func(c chan os.Signal) { // Wait for a SIGINT or SIGKILL: for { sig := <-c @@ -676,7 +680,7 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP + // TODO(kradalby): Reload config on SIGHUP if h.cfg.ACL.PolicyPath != "" { aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) @@ -696,7 +700,8 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") - h.shutdownChan <- struct{}{} + close(h.shutdownChan) + h.pollNetMapStreamWG.Wait() // Gracefully shut down servers ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) @@ -708,6 +713,11 @@ func (h *Headscale) Serve() error { } grpcSocket.GracefulStop() + if grpcServer != nil { + grpcServer.GracefulStop() + grpcListener.Close() + } + // Close network listeners promHTTPListener.Close() httpListener.Close() @@ -734,7 +744,12 @@ func (h *Headscale) Serve() error { os.Exit(0) } } - }(sigc) + } + errorGroup.Go(func() error { + sigFunc(sigc) + + return nil + }) return errorGroup.Wait() } @@ -758,13 +773,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } switch h.cfg.TLS.LetsEncrypt.ChallengeType { - case "TLS-ALPN-01": + case tlsALPN01ChallengeType: // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // The RFC requires that the validation is done on port 443; in other words, headscale // must be reachable on port 443. return certManager.TLSConfig(), nil - case "HTTP-01": + case http01ChallengeType: // Configuration via autocert with HTTP-01. This requires listening on // port 80 for the certificate validation in addition to the headscale // service, which can be configured to run on any other port. diff --git a/config.go b/config.go index 6789f6f..6935840 100644 --- a/config.go +++ b/config.go @@ -18,6 +18,11 @@ import ( "tailscale.com/types/dnstype" ) +const ( + tlsALPN01ChallengeType = "TLS-ALPN-01" + http01ChallengeType = "HTTP-01" +) + // Config contains the initial Headscale configuration. type Config struct { ServerURL string @@ -136,7 +141,7 @@ func LoadConfig(path string, isFile bool) error { viper.AutomaticEnv() viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") - viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") + viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType) viper.SetDefault("tls_client_auth_mode", "relaxed") viper.SetDefault("log_level", "info") @@ -179,15 +184,15 @@ func LoadConfig(path string, isFile bool) error { } if (viper.GetString("tls_letsencrypt_hostname") != "") && - (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && + (viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) log.Warn(). Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") } - if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && - (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { + if (viper.GetString("tls_letsencrypt_challenge_type") != http01ChallengeType) && + (viper.GetString("tls_letsencrypt_challenge_type") != tlsALPN01ChallengeType) { errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" } diff --git a/poll.go b/poll.go index 6628a17..9c17b5c 100644 --- a/poll.go +++ b/poll.go @@ -290,6 +290,9 @@ func (h *Headscale) PollNetMapStream( keepAliveChan chan []byte, updateChan chan struct{}, ) { + h.pollNetMapStreamWG.Add(1) + defer h.pollNetMapStreamWG.Done() + ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) ctx, cancel := context.WithCancel(ctx)