diff --git a/app.go b/app.go index 8df56dd..26ec956 100644 --- a/app.go +++ b/app.go @@ -171,7 +171,7 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) { // Client cert is _required and verified_. return tls.RequireAndVerifyClientCert, true default: - // Return the default when an unknown value is supplied. + // Return the default when an unknown value is supplied. return tls.RequireAnyClientCert, false } } diff --git a/app_test.go b/app_test.go index 3df5948..53c703a 100644 --- a/app_test.go +++ b/app_test.go @@ -69,9 +69,8 @@ func (s *Suite) ResetDB(c *check.C) { // Enusre an error is returned when an invalid auth mode // is supplied. func (s *Suite) TestInvalidClientAuthMode(c *check.C) { - app.cfg.TLSClientAuthMode = "invalid" - _, err := app.GetClientAuthMode() - c.Assert(err, check.NotNil) + _, isValid := LookupTLSClientAuthMode("invalid") + c.Assert(isValid, check.Equals, false) } // Ensure that all client auth modes return a nil error. @@ -79,8 +78,7 @@ func (s *Suite) TestAuthModes(c *check.C) { modes := []string{"disabled", "relaxed", "enforced"} for _, v := range modes { - app.cfg.TLSClientAuthMode = v - _, err := app.GetClientAuthMode() - c.Assert(err, check.IsNil) + _, isValid := LookupTLSClientAuthMode(v) + c.Assert(isValid, check.Equals, true) } } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index dbcc8bb..9316302 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -34,7 +34,6 @@ const ( ) func LoadConfig(path string) error { - viper.SetConfigName("config") if path == "" { viper.AddConfigPath("/etc/headscale/") @@ -98,12 +97,12 @@ func LoadConfig(path string) error { _, authModeValid := headscale.LookupTLSClientAuthMode(viper.GetString("tls_client_auth_mode")) if !authModeValid { - errorText += fmt.Sprintf( - "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.", - viper.GetString("tls_client_auth_mode"), - headscale.DisabledClientAuth, - headscale.RelaxedClientAuth, - headscale.EnforcedClientAuth) + errorText += fmt.Sprintf( + "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.", + viper.GetString("tls_client_auth_mode"), + headscale.DisabledClientAuth, + headscale.RelaxedClientAuth, + headscale.EnforcedClientAuth) } if errorText != "" { @@ -295,7 +294,9 @@ func getHeadscaleConfig() headscale.Config { Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes) } - tlsClientAuthMode, _ := headscale.LookupTLSClientAuthMode(viper.GetString("tls_client_auth_mode")) + tlsClientAuthMode, _ := headscale.LookupTLSClientAuthMode( + viper.GetString("tls_client_auth_mode"), + ) return headscale.Config{ ServerURL: viper.GetString("server_url"),