Merge pull request #601 from kradalby/signals-reload-acl

This commit is contained in:
Kristoffer Dalby 2022-06-03 10:48:43 +02:00 committed by GitHub
commit 0797148076
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 39 deletions

View file

@ -20,6 +20,7 @@
- Add option to enable/disable logtail (Tailscale's logging infrastructure) [#596](https://github.com/juanfont/headscale/pull/596) - Add option to enable/disable logtail (Tailscale's logging infrastructure) [#596](https://github.com/juanfont/headscale/pull/596)
- This change disables the logs by default - This change disables the logs by default
- Use [Prometheus]'s duration parser, supporting days (`d`), weeks (`w`) and years (`y`) [#598](https://github.com/juanfont/headscale/pull/598) - Use [Prometheus]'s duration parser, supporting days (`d`), weeks (`w`) and years (`y`) [#598](https://github.com/juanfont/headscale/pull/598)
- Add support for reloading ACLs with SIGHUP [#601](https://github.com/juanfont/headscale/pull/601)
## 0.15.0 (2022-03-20) ## 0.15.0 (2022-03-20)

74
app.go
View file

@ -116,6 +116,8 @@ type Config struct {
LogTail LogTailConfig LogTail LogTailConfig
CLI CLIConfig CLI CLIConfig
ACL ACLConfig
} }
type OIDCConfig struct { type OIDCConfig struct {
@ -152,6 +154,10 @@ type CLIConfig struct {
Insecure bool Insecure bool
} }
type ACLConfig struct {
PolicyPath string
}
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg Config cfg Config
@ -568,19 +574,6 @@ func (h *Headscale) Serve() error {
return fmt.Errorf("failed change permission of gRPC socket: %w", err) return fmt.Errorf("failed change permission of gRPC socket: %w", err)
} }
// Handle common process-killing signals so we can gracefully shut down:
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, os.Interrupt, syscall.SIGTERM)
go func(c chan os.Signal) {
// Wait for a SIGINT or SIGKILL:
sig := <-c
log.Printf("Caught signal %s: shutting down.", sig)
// Stop listening (and unlink the socket if unix type):
socketListener.Close()
// And we're done:
os.Exit(0)
}(sigc)
grpcGatewayMux := runtime.NewServeMux() grpcGatewayMux := runtime.NewServeMux()
// Make the grpc-gateway connect to grpc over socket // Make the grpc-gateway connect to grpc over socket
@ -725,6 +718,61 @@ func (h *Headscale) Serve() error {
log.Info(). log.Info().
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)
// Handle common process-killing signals so we can gracefully shut down:
sigc := make(chan os.Signal, 1)
signal.Notify(sigc,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
syscall.SIGHUP)
go func(c chan os.Signal) {
// Wait for a SIGINT or SIGKILL:
for {
sig := <-c
switch sig {
case syscall.SIGHUP:
log.Info().
Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP
if h.cfg.ACL.PolicyPath != "" {
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
err := h.LoadACLPolicy(aclPath)
if err != nil {
log.Error().Err(err).Msg("Failed to reload ACL policy")
}
log.Info().
Str("path", aclPath).
Msg("ACL policy successfully reloaded")
}
default:
log.Info().
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")
// Gracefully shut down servers
promHTTPServer.Shutdown(ctx)
httpServer.Shutdown(ctx)
grpcSocket.GracefulStop()
// Close network listeners
promHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()
// Stop listening (and unlink the socket if unix type):
socketListener.Close()
// And we're done:
os.Exit(0)
}
}
}(sigc)
return errorGroup.Wait() return errorGroup.Wait()
} }

View file

@ -9,7 +9,6 @@ import (
"io/fs" "io/fs"
"net/url" "net/url"
"os" "os"
"path/filepath"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -169,7 +168,7 @@ func GetDERPConfig() headscale.DERPConfig {
} }
} }
func GetLogConfig() headscale.LogTailConfig { func GetLogTailConfig() headscale.LogTailConfig {
enabled := viper.GetBool("logtail.enabled") enabled := viper.GetBool("logtail.enabled")
return headscale.LogTailConfig{ return headscale.LogTailConfig{
@ -177,6 +176,14 @@ func GetLogConfig() headscale.LogTailConfig {
} }
} }
func GetACLConfig() headscale.ACLConfig {
policyPath := viper.GetString("acl_policy_path")
return headscale.ACLConfig{
PolicyPath: policyPath,
}
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) { func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") { if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{} dnsConfig := &tailcfg.DNSConfig{}
@ -264,23 +271,10 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
return nil, "" return nil, ""
} }
func absPath(path string) string { func GetHeadscaleConfig() headscale.Config {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}
return path
}
func getHeadscaleConfig() headscale.Config {
dnsConfig, baseDomain := GetDNSConfig() dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig() derpConfig := GetDERPConfig()
logConfig := GetLogConfig() logConfig := GetLogTailConfig()
configuredPrefixes := viper.GetStringSlice("ip_prefixes") configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1) parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
@ -342,7 +336,7 @@ func getHeadscaleConfig() headscale.Config {
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
IPPrefixes: prefixes, IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain, BaseDomain: baseDomain,
DERP: derpConfig, DERP: derpConfig,
@ -352,7 +346,7 @@ func getHeadscaleConfig() headscale.Config {
), ),
DBtype: viper.GetString("db_type"), DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")), DBpath: headscale.AbsolutePathFromConfigPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"), DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"), DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"), DBname: viper.GetString("db_name"),
@ -361,13 +355,13 @@ func getHeadscaleConfig() headscale.Config {
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"), TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath( TLSLetsEncryptCacheDir: headscale.AbsolutePathFromConfigPath(
viper.GetString("tls_letsencrypt_cache_dir"), viper.GetString("tls_letsencrypt_cache_dir"),
), ),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
TLSCertPath: absPath(viper.GetString("tls_cert_path")), TLSCertPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")), TLSKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_key_path")),
TLSClientAuthMode: tlsClientAuthMode, TLSClientAuthMode: tlsClientAuthMode,
DNSConfig: dnsConfig, DNSConfig: dnsConfig,
@ -397,6 +391,8 @@ func getHeadscaleConfig() headscale.Config {
Timeout: viper.GetDuration("cli.timeout"), Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"), Insecure: viper.GetBool("cli.insecure"),
}, },
ACL: GetACLConfig(),
} }
} }
@ -416,7 +412,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err return nil, err
} }
cfg := getHeadscaleConfig() cfg := GetHeadscaleConfig()
app, err := headscale.NewHeadscale(cfg) app, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
@ -425,8 +421,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
// We are doing this here, as in the future could be cool to have it also hot-reload // We are doing this here, as in the future could be cool to have it also hot-reload
if viper.GetString("acl_policy_path") != "" { if cfg.ACL.PolicyPath != "" {
aclPath := absPath(viper.GetString("acl_policy_path")) aclPath := headscale.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
err = app.LoadACLPolicy(aclPath) err = app.LoadACLPolicy(aclPath)
if err != nil { if err != nil {
log.Fatal(). log.Fatal().
@ -440,7 +436,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
} }
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg := getHeadscaleConfig() cfg := GetHeadscaleConfig()
log.Debug(). log.Debug().
Dur("timeout", cfg.CLI.Timeout). Dur("timeout", cfg.CLI.Timeout).

View file

@ -12,10 +12,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"os"
"path/filepath"
"reflect" "reflect"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/viper"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -334,3 +337,16 @@ func IsStringInSlice(slice []string, str string) bool {
return false return false
} }
func AbsolutePathFromConfigPath(path string) string {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}
return path
}