From 37601f6b4d2ef6173d0ea5040fc18eb68daf45e8 Mon Sep 17 00:00:00 2001 From: Ward Vandewege Date: Sun, 25 Apr 2021 11:24:42 -0400 Subject: [PATCH 1/2] Add a very simple test. --- Makefile | 2 +- cmd/headscale/headscale.go | 18 ++++++++--- cmd/headscale/headscale_test.go | 56 +++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 cmd/headscale/headscale_test.go diff --git a/Makefile b/Makefile index b053e29..f6bf5ba 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ build: dev: lint test build test: - go test -coverprofile=coverage.out + @go test -coverprofile=coverage.out ./... coverprofile_func: go tool cover -func=coverage.out diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 634e85f..f9dcaa1 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -249,11 +249,16 @@ var createPreAuthKeyCmd = &cobra.Command{ }, } -func main() { +func loadConfig(path string) { viper.SetConfigName("config") - viper.AddConfigPath("/etc/headscale/") - viper.AddConfigPath("$HOME/.headscale") - viper.AddConfigPath(".") + if path == "" { + viper.AddConfigPath("/etc/headscale/") + viper.AddConfigPath("$HOME/.headscale") + viper.AddConfigPath(".") + } else { + // For testing + viper.AddConfigPath(path) + } viper.AutomaticEnv() viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") @@ -279,6 +284,10 @@ func main() { if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") { log.Fatalf("Fatal config error: server_url must start with https:// or http://") } +} + +func main() { + loadConfig("") headscaleCmd.AddCommand(versionCmd) headscaleCmd.AddCommand(serveCmd) @@ -302,7 +311,6 @@ func main() { fmt.Println(err) os.Exit(-1) } - } func absPath(path string) string { diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go new file mode 100644 index 0000000..c1fa3c0 --- /dev/null +++ b/cmd/headscale/headscale_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/spf13/viper" + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +func (s *Suite) SetUpSuite(c *check.C) { +} + +func (s *Suite) TearDownSuite(c *check.C) { + +} + +func (*Suite) TestConfigLoading(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + path, err := os.Getwd() + if err != nil { + c.Fatal(err) + } + + // Symlink the example config file + err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json")) + if err != nil { + c.Fatal(err) + } + + // Load config + loadConfig(tmpDir) + + // Test that config file was interpreted correctly + c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") + c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") + c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") + c.Assert(viper.GetString("db_port"), check.Equals, "5432") + c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") + c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") +} From f5010fd75b0ecf1e127d5edf1efa8f59b89f4806 Mon Sep 17 00:00:00 2001 From: Ward Vandewege Date: Mon, 26 Apr 2021 20:30:06 -0400 Subject: [PATCH 2/2] Add test for our config validation rules. --- cmd/headscale/headscale.go | 25 +++++++++++++----- cmd/headscale/headscale_test.go | 46 +++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index f9dcaa1..776ffb9 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io" "log" @@ -249,7 +250,7 @@ var createPreAuthKeyCmd = &cobra.Command{ }, } -func loadConfig(path string) { +func loadConfig(path string) error { viper.SetConfigName("config") if path == "" { viper.AddConfigPath("/etc/headscale/") @@ -266,28 +267,38 @@ func loadConfig(path string) { err := viper.ReadInConfig() if err != nil { - log.Fatalf("Fatal error config file: %s \n", err) + return errors.New(fmt.Sprintf("Fatal error reading config file: %s \n", err)) } + // Collect any validation errors and return them all at once + var errorText string if (viper.GetString("tls_letsencrypt_hostname") != "") && ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) { - log.Fatalf("Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") + errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" } if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { - log.Fatalf("Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443") + errorText += "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443\n" } if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { - log.Fatalf("Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01") + errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" } if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") { - log.Fatalf("Fatal config error: server_url must start with https:// or http://") + errorText += "Fatal config error: server_url must start with https:// or http://\n" + } + if errorText != "" { + return errors.New(strings.TrimSuffix(errorText, "\n")) + } else { + return nil } } func main() { - loadConfig("") + err := loadConfig("") + if err != nil { + log.Fatalf(err.Error()) + } headscaleCmd.AddCommand(versionCmd) headscaleCmd.AddCommand(serveCmd) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index c1fa3c0..a3894f6 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -1,9 +1,11 @@ package main import ( + "fmt" "io/ioutil" "os" "path/filepath" + "strings" "testing" "github.com/spf13/viper" @@ -43,8 +45,9 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Fatal(err) } - // Load config - loadConfig(tmpDir) + // Load example config, it should load without validation errors + err = loadConfig(tmpDir) + c.Assert(err, check.IsNil) // Test that config file was interpreted correctly c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") @@ -54,3 +57,42 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") } + +func writeConfig(c *check.C, tmpDir string, configYaml []byte) { + // Populate a custom config file + configFile := filepath.Join(tmpDir, "config.yaml") + err := ioutil.WriteFile(configFile, configYaml, 0644) + if err != nil { + c.Fatalf("Couldn't write file %s", configFile) + } +} + +func (*Suite) TestTLSConfigValidation(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + //defer os.RemoveAll(tmpDir) + fmt.Println(tmpDir) + + configYaml := []byte("---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"") + writeConfig(c, tmpDir, configYaml) + + // Check configuration validation errors (1) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + // check.Matches can not handle multiline strings + tmp := strings.ReplaceAll(err.Error(), "\n", "***") + c.Assert(tmp, check.Matches, ".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*") + fmt.Println(tmp) + + // Check configuration validation errors (2) + configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") + fmt.Printf(string(configYaml)) + writeConfig(c, tmpDir, configYaml) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + c.Assert(err, check.ErrorMatches, "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443.*") +}