diff --git a/api.go b/api.go index 344878c..845a320 100644 --- a/api.go +++ b/api.go @@ -3,6 +3,7 @@ package headscale import ( "encoding/binary" "encoding/json" + "errors" "fmt" "io" "log" @@ -10,9 +11,9 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" "github.com/klauspost/compress/zstd" "gorm.io/datatypes" + "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/wgengine/wgcfg" @@ -80,10 +81,9 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { c.String(http.StatusInternalServerError, ":(") return } - defer db.Close() var m Machine - if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { + if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Println("New Machine!") m = Machine{ Expiry: &req.Expiry, @@ -209,9 +209,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { log.Printf("Cannot open DB: %s", err) return } - defer db.Close() var m Machine - if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { + if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString()) return } diff --git a/app.go b/app.go index 947980d..4a846e2 100644 --- a/app.go +++ b/app.go @@ -112,7 +112,6 @@ func (h *Headscale) expireEphemeralNodesWorker() { log.Printf("Cannot open DB: %s", err) return } - defer db.Close() namespaces, err := h.ListNamespaces() if err != nil { diff --git a/cli.go b/cli.go index 9829ac3..31419f3 100644 --- a/cli.go +++ b/cli.go @@ -4,6 +4,7 @@ import ( "errors" "log" + "gorm.io/gorm" "tailscale.com/wgengine/wgcfg" ) @@ -22,9 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() m := Machine{} - if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { + if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errors.New("Machine not found") } diff --git a/db.go b/db.go index 5e08d9b..f849d01 100644 --- a/db.go +++ b/db.go @@ -3,9 +3,9 @@ package headscale import ( "errors" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/postgres" // sql driver - _ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) const dbVersion = "1" @@ -17,30 +17,49 @@ type KV struct { } func (h *Headscale) initDB() error { - db, err := gorm.Open(h.dbType, h.dbString) + db, err := h.db() if err != nil { return err } if h.dbType == "postgres" { db.Exec("create extension if not exists \"uuid-ossp\";") } - db.AutoMigrate(&Machine{}) - db.AutoMigrate(&KV{}) - db.AutoMigrate(&Namespace{}) - db.AutoMigrate(&PreAuthKey{}) - db.Close() + err = db.AutoMigrate(&Machine{}) + if err != nil { + return err + } + err = db.AutoMigrate(&KV{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Namespace{}) + if err != nil { + return err + } + err = db.AutoMigrate(&PreAuthKey{}) + if err != nil { + return err + } err = h.setValue("db_version", dbVersion) return err } func (h *Headscale) db() (*gorm.DB, error) { - db, err := gorm.Open(h.dbType, h.dbString) + var db *gorm.DB + var err error + switch h.dbType { + case "sqlite3": + db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{}) + case "postgres": + db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{}) + } + if err != nil { return nil, err } if h.dbDebug { - db.LogMode(true) + db.Debug() } return db, nil } @@ -50,9 +69,8 @@ func (h *Headscale) getValue(key string) (string, error) { if err != nil { return "", err } - defer db.Close() var row KV - if db.First(&row, "key = ?", key).RecordNotFound() { + if result := db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", errors.New("not found") } return row.Value, nil @@ -67,7 +85,6 @@ func (h *Headscale) setValue(key string, value string) error { if err != nil { return err } - defer db.Close() _, err = h.getValue(key) if err == nil { db.Model(&kv).Where("key = ?", key).Update("value", value) diff --git a/machine.go b/machine.go index 548ce13..9b15b1a 100644 --- a/machine.go +++ b/machine.go @@ -159,7 +159,6 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) { log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() machines := []Machine{} if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered", diff --git a/namespaces.go b/namespaces.go index 723fd6b..0eedad7 100644 --- a/namespaces.go +++ b/namespaces.go @@ -1,10 +1,11 @@ package headscale import ( + "errors" "log" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -29,7 +30,6 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() n := Namespace{} if err := db.Where("name = ?", name).First(&n).Error; err == nil { @@ -51,7 +51,6 @@ func (h *Headscale) DestroyNamespace(name string) error { log.Printf("Cannot open DB: %s", err) return err } - defer db.Close() n, err := h.GetNamespace(name) if err != nil { @@ -66,8 +65,7 @@ func (h *Headscale) DestroyNamespace(name string) error { return errorNamespaceNotEmpty } - err = db.Unscoped().Delete(&n).Error - if err != nil { + if result := db.Unscoped().Delete(&n); result.Error != nil { return err } @@ -81,10 +79,9 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) { log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() n := Namespace{} - if db.First(&n, "name = ?", name).RecordNotFound() { + if result := db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errorNamespaceNotFound } return &n, nil @@ -97,7 +94,6 @@ func (h *Headscale) ListNamespaces() (*[]Namespace, error) { log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() namespaces := []Namespace{} if err := db.Find(&namespaces).Error; err != nil { return nil, err @@ -116,7 +112,6 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) { log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() machines := []Machine{} if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { @@ -136,7 +131,6 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error log.Printf("Cannot open DB: %s", err) return err } - defer db.Close() m.NamespaceID = n.ID db.Save(&m) return nil diff --git a/preauth_keys.go b/preauth_keys.go index 2ec3df7..f0346e6 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -3,8 +3,11 @@ package headscale import ( "crypto/rand" "encoding/hex" + "errors" "log" "time" + + "gorm.io/gorm" ) const errorAuthKeyNotFound = Error("AuthKey not found") @@ -36,7 +39,6 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, epheme log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() now := time.Now().UTC() kstr, err := h.generateKey() @@ -69,7 +71,6 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) log.Printf("Cannot open DB: %s", err) return nil, err } - defer db.Close() keys := []PreAuthKey{} if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { @@ -85,10 +86,9 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { if err != nil { return nil, err } - defer db.Close() pak := PreAuthKey{} - if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() { + if result := db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errorAuthKeyNotFound } diff --git a/routes.go b/routes.go index 59c32bb..8b09e3f 100644 --- a/routes.go +++ b/routes.go @@ -51,7 +51,6 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest m.EnabledRoutes = datatypes.JSON(routes) db.Save(&m) - db.Close() // THIS IS COMPLETELY USELESS. // The peers map is stored in memory in the server process. diff --git a/utils.go b/utils.go index 55310c8..eff20e2 100644 --- a/utils.go +++ b/utils.go @@ -18,6 +18,7 @@ import ( mathrand "math/rand" "golang.org/x/crypto/nacl/box" + "gorm.io/gorm" "tailscale.com/wgengine/wgcfg" ) @@ -81,7 +82,6 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) { if err != nil { return nil, err } - defer db.Close() i := 0 for { ip, err := getRandomIP() @@ -89,7 +89,7 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) { return nil, err } m := Machine{} - if db.First(&m, "ip_address = ?", ip.String()).RecordNotFound() { + if result := db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { return ip, nil } i++