cli.LoadConfig accepts config file now
This commit is contained in:
parent
adb55bcfe9
commit
0363e58467
3 changed files with 62 additions and 13 deletions
|
@ -33,7 +33,10 @@ const (
|
|||
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
func LoadConfig(path string) error {
|
||||
func LoadConfig(path string, isFile bool) error {
|
||||
if isFile {
|
||||
viper.SetConfigFile(path)
|
||||
} else {
|
||||
viper.SetConfigName("config")
|
||||
if path == "" {
|
||||
viper.AddConfigPath("/etc/headscale/")
|
||||
|
@ -43,6 +46,7 @@ func LoadConfig(path string) error {
|
|||
// For testing
|
||||
viper.AddConfigPath(path)
|
||||
}
|
||||
}
|
||||
|
||||
viper.SetEnvPrefix("headscale")
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
|
|
@ -43,7 +43,7 @@ func main() {
|
|||
NoColor: !colors,
|
||||
})
|
||||
|
||||
if err := cli.LoadConfig(""); err != nil {
|
||||
if err := cli.LoadConfig("", false); err != nil {
|
||||
log.Fatal().Caller().Err(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,51 @@ func (s *Suite) SetUpSuite(c *check.C) {
|
|||
func (s *Suite) TearDownSuite(c *check.C) {
|
||||
}
|
||||
|
||||
func (*Suite) TestConfigFileLoading(c *check.C) {
|
||||
tmpDir, err := ioutil.TempDir("", "headscale")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
path, err := os.Getwd()
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
||||
cfgFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Symlink the example config file
|
||||
err = os.Symlink(
|
||||
filepath.Clean(path+"/../../config-example.yaml"),
|
||||
cfgFile,
|
||||
)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
||||
// Load example config, it should load without validation errors
|
||||
err = cli.LoadConfig(cfgFile, true)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Test that config file was interpreted correctly
|
||||
c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080")
|
||||
c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8080")
|
||||
c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090")
|
||||
c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3")
|
||||
c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
||||
c.Assert(
|
||||
cli.GetFileMode("unix_socket_permission"),
|
||||
check.Equals,
|
||||
fs.FileMode(0o770),
|
||||
)
|
||||
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
|
||||
}
|
||||
|
||||
func (*Suite) TestConfigLoading(c *check.C) {
|
||||
tmpDir, err := ioutil.TempDir("", "headscale")
|
||||
if err != nil {
|
||||
|
@ -49,7 +94,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
|||
}
|
||||
|
||||
// Load example config, it should load without validation errors
|
||||
err = cli.LoadConfig(tmpDir)
|
||||
err = cli.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Test that config file was interpreted correctly
|
||||
|
@ -92,7 +137,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
|
|||
}
|
||||
|
||||
// Load example config, it should load without validation errors
|
||||
err = cli.LoadConfig(tmpDir)
|
||||
err = cli.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig, baseDomain := cli.GetDNSConfig()
|
||||
|
@ -125,7 +170,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
|||
writeConfig(c, tmpDir, configYaml)
|
||||
|
||||
// Check configuration validation errors (1)
|
||||
err = cli.LoadConfig(tmpDir)
|
||||
err = cli.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.NotNil)
|
||||
// check.Matches can not handle multiline strings
|
||||
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
|
||||
|
@ -150,6 +195,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
|||
"---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"",
|
||||
)
|
||||
writeConfig(c, tmpDir, configYaml)
|
||||
err = cli.LoadConfig(tmpDir)
|
||||
err = cli.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue