diff --git a/.gitignore b/.gitignore index abb466e..ff4f666 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ /headscale config.json *.key +/db.sqlite diff --git a/README.md b/README.md index 473eca3..950b57d 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Suggestions/PRs welcomed! make ``` -2. Get yourself a PostgreSQL DB running (yes, [I know](https://tailscale.com/blog/an-unlikely-database-migration/)) +2. (Optional, you can also use SQLite) Get yourself a PostgreSQL DB running ```shell docker run --name headscale -e POSTGRES_DB=headscale -e \ @@ -55,7 +55,12 @@ Suggestions/PRs welcomed! ```shell wg genkey > private.key wg pubkey < private.key > public.key # not needed - cp config.json.example config.json + + # Postgres + cp config.json.postgres.example config.json + # or + # SQLite + cp config.json.sqlite.example config.json ``` 4. Create a namespace (equivalent to a user in tailscale.com) diff --git a/api.go b/api.go index 56d858c..3c9c0ae 100644 --- a/api.go +++ b/api.go @@ -279,12 +279,14 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgc return default: + h.pollMu.Lock() data, err := h.getMapKeepAliveResponse(mKey, req, m) if err != nil { log.Printf("Error generating the keep alive msg: %s", err) return } pollData <- *data + h.pollMu.Unlock() time.Sleep(60 * time.Second) } } diff --git a/app.go b/app.go index d7fd3f2..917522f 100644 --- a/app.go +++ b/app.go @@ -22,6 +22,8 @@ type Config struct { PrivateKeyPath string DerpMap *tailcfg.DERPMap + DBtype string + DBpath string DBhost string DBport int DBname string @@ -60,11 +62,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } pubKey := privKey.Public() + + var dbString string + switch cfg.DBtype { + case "postgres": + dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, + cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass) + case "sqlite3": + dbString = cfg.DBpath + default: + return nil, errors.New("Unsupported DB") + } + h := Headscale{ - cfg: cfg, - dbType: "postgres", - dbString: fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, - cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass), + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, privateKey: privKey, publicKey: &pubKey, } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index a9897fb..b30e43a 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -22,10 +22,10 @@ type ErrorOutput struct { func absPath(path string) string { // If a relative path is provided, prefix it with the the directory where // the config file was found. - if (path != "") && !strings.HasPrefix(path, "/") { + if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) { dir, _ := filepath.Split(viper.ConfigFileUsed()) if dir != "" { - path = dir + "/" + path + path = filepath.Join(dir, path) } } return path @@ -43,6 +43,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { PrivateKeyPath: absPath(viper.GetString("private_key_path")), DerpMap: derpMap, + DBtype: viper.GetString("db_type"), + DBpath: absPath(viper.GetString("db_path")), DBhost: viper.GetString("db_host"), DBport: viper.GetInt("db_port"), DBname: viper.GetString("db_name"), diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index baa2a65..ed142cd 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -27,7 +27,7 @@ func (s *Suite) TearDownSuite(c *check.C) { } -func (*Suite) TestConfigLoading(c *check.C) { +func (*Suite) TestPostgresConfigLoading(c *check.C) { tmpDir, err := ioutil.TempDir("", "headscale") if err != nil { c.Fatal(err) @@ -40,7 +40,7 @@ func (*Suite) TestConfigLoading(c *check.C) { } // Symlink the example config file - err = os.Symlink(filepath.Clean(path+"/../../config.json.example"), filepath.Join(tmpDir, "config.json")) + err = os.Symlink(filepath.Clean(path+"/../../config.json.postgres.example"), filepath.Join(tmpDir, "config.json")) if err != nil { c.Fatal(err) } @@ -50,14 +50,47 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(err, check.IsNil) // Test that config file was interpreted correctly - c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") + c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000") c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") + c.Assert(viper.GetString("db_type"), check.Equals, "postgres") c.Assert(viper.GetString("db_port"), check.Equals, "5432") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") } +func (*Suite) TestSqliteConfigLoading(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + path, err := os.Getwd() + if err != nil { + c.Fatal(err) + } + + // Symlink the example config file + err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json")) + if err != nil { + c.Fatal(err) + } + + // Load example config, it should load without validation errors + err = loadConfig(tmpDir) + c.Assert(err, check.IsNil) + + // Test that config file was interpreted correctly + c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8000") + c.Assert(viper.GetString("listen_addr"), check.Equals, "0.0.0.0:8000") + c.Assert(viper.GetString("derp_map_path"), check.Equals, "derp.yaml") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_path"), check.Equals, "db.sqlite") + c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") + c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") +} + func writeConfig(c *check.C, tmpDir string, configYaml []byte) { // Populate a custom config file configFile := filepath.Join(tmpDir, "config.yaml") @@ -89,7 +122,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { fmt.Println(tmp) // Check configuration validation errors (2) - configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") + configYaml = []byte("---\nserver_url: \"http://127.0.0.1:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") writeConfig(c, tmpDir, configYaml) err = loadConfig(tmpDir) c.Assert(err, check.NotNil) diff --git a/config.json.example b/config.json.postgres.example similarity index 85% rename from config.json.example rename to config.json.postgres.example index 5811dac..6436ec5 100644 --- a/config.json.example +++ b/config.json.postgres.example @@ -1,8 +1,9 @@ { - "server_url": "http://192.168.1.12:8000", + "server_url": "http://127.0.0.1:8000", "listen_addr": "0.0.0.0:8000", "private_key_path": "private.key", "derp_map_path": "derp.yaml", + "db_type": "postgres", "db_host": "localhost", "db_port": 5432, "db_name": "headscale", diff --git a/config.json.sqlite.example b/config.json.sqlite.example new file mode 100644 index 0000000..2ae821b --- /dev/null +++ b/config.json.sqlite.example @@ -0,0 +1,13 @@ +{ + "server_url": "http://127.0.0.1:8000", + "listen_addr": "0.0.0.0:8000", + "private_key_path": "private.key", + "derp_map_path": "derp.yaml", + "db_type": "sqlite3", + "db_path": "db.sqlite", + "tls_letsencrypt_hostname": "", + "tls_letsencrypt_cache_dir": ".cache", + "tls_letsencrypt_challenge_type": "HTTP-01", + "tls_cert_path": "", + "tls_key_path": "" +} diff --git a/db.go b/db.go index 0723c52..5e08d9b 100644 --- a/db.go +++ b/db.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/postgres" // sql driver + _ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver ) const dbVersion = "1"