From 7fd2485000c743666316f4eaf691967de7030361 Mon Sep 17 00:00:00 2001 From: MichaelKo Date: Thu, 16 May 2024 02:40:14 +0200 Subject: [PATCH] Restore foreign keys and add constraints (#1562) * fix #1482, restore foregin keys, add constraints * #1562, fix tests, fix formatting * #1562: fix tests * #1562: fix local run of test_integration --- CHANGELOG.md | 1 + Makefile | 1 + hscontrol/auth.go | 11 ++++-- hscontrol/db/db.go | 9 +++-- hscontrol/db/ip_test.go | 26 ++++++++++++++ hscontrol/db/node.go | 2 +- hscontrol/db/node_test.go | 57 ++++++++++++++++++++----------- hscontrol/db/preauth_keys.go | 3 +- hscontrol/db/preauth_keys_test.go | 21 ++++++++---- hscontrol/db/routes_test.go | 43 ++++++++++++++++++----- hscontrol/db/users_test.go | 12 ++++--- hscontrol/mapper/mapper_test.go | 7 ++-- hscontrol/mapper/tail_test.go | 1 - hscontrol/types/node.go | 8 ++--- hscontrol/types/preauth_key.go | 8 ++--- 15 files changed, 149 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cd8283..a8e15c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869) - Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877) - Add `autogroup:internet` to Policy [#1917](https://github.com/juanfont/headscale/pull/1917) +- Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562) ## 0.22.3 (2023-05-12) diff --git a/Makefile b/Makefile index 442690e..719393f 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,7 @@ test_integration: --name headscale-test-suite \ -v $$PWD:$$PWD -w $$PWD/integration \ -v /var/run/docker.sock:/var/run/docker.sock \ + -v $$PWD/control_logs:/tmp/control \ golang:1 \ go run gotest.tools/gotestsum@latest -- -failfast ./... -timeout 120m -parallel 8 diff --git a/hscontrol/auth.go b/hscontrol/auth.go index dab9ff4..c4511db 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -314,7 +314,11 @@ func (h *Headscale) handleAuthKey( Msg("node was already registered before, refreshing with new auth key") node.NodeKey = nodeKey - node.AuthKeyID = uint(pak.ID) + pakID := uint(pak.ID) + if pakID != 0 { + node.AuthKeyID = &pakID + } + node.Expiry = ®isterRequest.Expiry node.User = pak.User node.UserID = pak.UserID @@ -373,7 +377,6 @@ func (h *Headscale) handleAuthKey( Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, LastSeen: &now, - AuthKeyID: uint(pak.ID), ForcedTags: pak.Proto().GetAclTags(), } @@ -389,6 +392,10 @@ func (h *Headscale) handleAuthKey( return } + pakID := uint(pak.ID) + if pakID != 0 { + nodeToRegister.AuthKeyID = &pakID + } node, err = h.db.RegisterNode( nodeToRegister, ipv4, ipv6, diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index c8ec337..a30939c 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -91,7 +91,8 @@ func NewHeadscaleDatabase( _ = tx.Migrator(). RenameColumn(&types.Node{}, "nickname", "given_name") - // If the Node table has a column for registered, + dbConn.Model(&types.Node{}).Where("auth_key_id = ?", 0).Update("auth_key_id", nil) + // If the Node table has a column for registered, // find all occourences of "false" and drop them. Then // remove the column. if tx.Migrator().HasColumn(&types.Node{}, "registered") { @@ -441,8 +442,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { db, err := gorm.Open( sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - Logger: dbLogger, + Logger: dbLogger, }, ) @@ -488,8 +488,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { } db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - Logger: dbLogger, + Logger: dbLogger, }) if err != nil { return nil, err diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index a651476..c922fcd 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -87,8 +87,11 @@ func TestIPAllocatorSequential(t *testing.T) { name: "simple-with-db", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-with-db") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -112,8 +115,11 @@ func TestIPAllocatorSequential(t *testing.T) { name: "before-after-free-middle-in-db", dbFunc: func() *HSDatabase { db := dbForTest(t, "before-after-free-middle-in-db") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.2"), IPv6: nap("fd7a:115c:a1e0::2"), }) @@ -307,8 +313,11 @@ func TestBackfillIPAddresses(t *testing.T) { name: "simple-backfill-ipv6", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-backfill-ipv6") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.1"), }) @@ -337,8 +346,11 @@ func TestBackfillIPAddresses(t *testing.T) { name: "simple-backfill-ipv4", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-backfill-ipv4") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -367,8 +379,11 @@ func TestBackfillIPAddresses(t *testing.T) { name: "simple-backfill-remove-ipv6", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-backfill-remove-ipv6") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -392,8 +407,11 @@ func TestBackfillIPAddresses(t *testing.T) { name: "simple-backfill-remove-ipv4", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-backfill-remove-ipv4") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -417,17 +435,23 @@ func TestBackfillIPAddresses(t *testing.T) { name: "multi-backfill-ipv6", dbFunc: func() *HSDatabase { db := dbForTest(t, "simple-backfill-ipv6") + user := types.User{Name: ""} + db.DB.Save(&user) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.1"), }) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.2"), }) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.3"), }) db.DB.Save(&types.Node{ + User: user, IPv4: nap("100.64.0.4"), }) @@ -451,6 +475,8 @@ func TestBackfillIPAddresses(t *testing.T) { "MachineKeyDatabaseField", "NodeKeyDatabaseField", "DiscoKeyDatabaseField", + "User", + "UserID", "Endpoints", "HostinfoDatabaseField", "Hostinfo", diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 91bf0cb..e9a4ea0 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -279,7 +279,7 @@ func DeleteNode(tx *gorm.DB, } // Unscoped causes the node to be fully removed from the database. - if err := tx.Unscoped().Delete(&node).Error; err != nil { + if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { return changed, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index ce2ada3..fa18765 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -29,6 +29,7 @@ func (s *Suite) TestGetNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -37,9 +38,10 @@ func (s *Suite) TestGetNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(node) + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) _, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) @@ -58,6 +60,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -65,9 +68,10 @@ func (s *Suite) TestGetNodeByID(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) _, err = db.GetNodeByID(0) c.Assert(err, check.IsNil) @@ -88,6 +92,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { machineKey := key.NewMachine() + pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -95,9 +100,10 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) @@ -117,9 +123,9 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { Hostname: "testnode3", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) c.Assert(err, check.IsNil) @@ -138,6 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) { _, err = db.GetNodeByID(0) c.Assert(err, check.NotNil) + pakID := uint(pak.ID) for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -149,9 +156,10 @@ func (s *Suite) TestListPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) } node0ByID, err := db.GetNodeByID(0) @@ -188,6 +196,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakID := uint(stor[index%2].key.ID) v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) node := types.Node{ @@ -198,9 +207,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(stor[index%2].key.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) } aclPolicy := &policy.ACLPolicy{ @@ -272,6 +282,7 @@ func (s *Suite) TestExpireNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -280,7 +291,7 @@ func (s *Suite) TestExpireNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Expiry: &time.Time{}, } db.DB.Save(node) @@ -316,6 +327,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { machineKey2 := key.NewMachine() + pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -324,9 +336,11 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { GivenName: "hostname-1", UserID: user1.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(node) + + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") @@ -357,6 +371,7 @@ func (s *Suite) TestSetTags(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -364,9 +379,11 @@ func (s *Suite) TestSetTags(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(node) + + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) // assign simple tags sTags := []string{"tag:test", "tag:foo"} @@ -548,6 +565,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { route2 := netip.MustParsePrefix("10.11.0.0/24") v4 := netip.MustParseAddr("100.64.0.1") + pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -555,7 +573,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { Hostname: "test", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:exit"}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, @@ -563,7 +581,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPv4: &v4, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 5d38de2..16a8689 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -197,9 +197,10 @@ func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { } nodes := types.Nodes{} + pakID := uint(pak.ID) if err := tx. Preload("AuthKey"). - Where(&types.Node{AuthKeyID: uint(pak.ID)}). + Where(&types.Node{AuthKeyID: &pakID}). Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index fa9681a..9cdcba8 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -76,14 +76,16 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) @@ -97,14 +99,16 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) + pakID := uint(pak.ID) node := types.Node{ ID: 1, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -131,15 +135,17 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-time.Second * 30) + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, LastSeen: &now, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -165,13 +171,14 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-time.Second * 30) + pakId := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, LastSeen: &now, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } db.DB.Save(&node) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 02342ca..8bbc594 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -43,15 +43,17 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_get_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &hostInfo, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) su, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -93,15 +95,17 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &hostInfo, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -165,15 +169,17 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } + pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &hostInfo1, } - db.DB.Save(&node1) + trx := db.DB.Save(&node1) + c.Assert(trx.Error, check.IsNil) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) @@ -193,7 +199,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &hostInfo2, } db.DB.Save(&node2) @@ -247,16 +253,18 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() + pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, Hostinfo: &hostInfo1, LastSeen: &now, } - db.DB.Save(&node1) + trx := db.DB.Save(&node1) + c.Assert(trx.Error, check.IsNil) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) @@ -617,7 +625,16 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { db := dbForTest(t, tt.name) + user := types.User{Name: tt.name} + if err := db.DB.Save(&user).Error; err != nil { + t.Fatalf("failed to create user: %s", err) + } + for _, route := range tt.routes { + route.Node.User = user + if err := db.DB.Save(&route.Node).Error; err != nil { + t.Fatalf("failed to create node: %s", err) + } if err := db.DB.Save(&route).Error; err != nil { t.Fatalf("failed to create route: %s", err) } @@ -1013,8 +1030,16 @@ func TestFailoverRouteTx(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := dbForTest(t, tt.name) + user := types.User{Name: "test"} + if err := db.DB.Save(&user).Error; err != nil { + t.Fatalf("failed to create user: %s", err) + } for _, route := range tt.routes { + route.Node.User = user + if err := db.DB.Save(&route.Node).Error; err != nil { + t.Fatalf("failed to create node: %s", err) + } if err := db.DB.Save(&route).Error; err != nil { t.Fatalf("failed to create route: %s", err) } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index b36e861..98dea6c 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -46,14 +46,16 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) err = db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) @@ -98,14 +100,16 @@ func (s *Suite) TestSetMachineUser(c *check.C) { pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakID, } - db.DB.Save(&node) + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) c.Assert(node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, newUser.Name) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index f624847..2ba3d03 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -187,10 +187,9 @@ func Test_fullMapResponse(t *testing.T) { UserID: 0, User: types.User{Name: "mini"}, ForcedTags: []string{}, - AuthKeyID: 0, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ { diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 229f0f8..47af68f 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -97,7 +97,6 @@ func TestTailNode(t *testing.T) { Name: "mini", }, ForcedTags: []string{}, - AuthKeyID: 0, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, Expiry: &expire, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index b0afe99..7a5756a 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -108,20 +108,20 @@ type Node struct { // parts of headscale. GivenName string `gorm:"type:varchar(63);unique_index"` UserID uint - User User `gorm:"foreignKey:UserID"` + User User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string ForcedTags StringList // TODO(kradalby): This seems like irrelevant information? - AuthKeyID uint - AuthKey *PreAuthKey + AuthKeyID *uint `sql:"DEFAULT:NULL"` + AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"` LastSeen *time.Time Expiry *time.Time - Routes []Route + Routes []Route `gorm:"constraint:OnDelete:CASCADE;"` CreatedAt time.Time UpdatedAt time.Time diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 0d8c9cf..8b02569 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -14,11 +14,11 @@ type PreAuthKey struct { ID uint64 `gorm:"primary_key"` Key string UserID uint - User User + User User `gorm:"constraint:OnDelete:CASCADE;"` Reusable bool - Ephemeral bool `gorm:"default:false"` - Used bool `gorm:"default:false"` - ACLTags []PreAuthKeyACLTag + Ephemeral bool `gorm:"default:false"` + Used bool `gorm:"default:false"` + ACLTags []PreAuthKeyACLTag `gorm:"constraint:OnDelete:CASCADE;"` CreatedAt *time.Time Expiration *time.Time