From b918aa03fc360fd9faf7c36b404b0ab1c4cfc9c6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 21 Nov 2023 18:20:06 +0100 Subject: [PATCH] move to use tailscfg types over strings/custom types (#1612) * rename database only fields Signed-off-by: Kristoffer Dalby * use correct endpoint type over string list Signed-off-by: Kristoffer Dalby * remove HostInfo wrapper Signed-off-by: Kristoffer Dalby * wrap errors in database hooks Signed-off-by: Kristoffer Dalby --------- Signed-off-by: Kristoffer Dalby --- hscontrol/db/node_test.go | 3 +- hscontrol/db/routes.go | 2 +- hscontrol/db/routes_test.go | 18 ++--- hscontrol/grpcv1.go | 2 +- hscontrol/mapper/mapper.go | 2 +- hscontrol/mapper/mapper_test.go | 9 +-- hscontrol/mapper/tail.go | 15 ++--- hscontrol/mapper/tail_test.go | 9 +-- hscontrol/policy/acls.go | 18 +++-- hscontrol/policy/acls_test.go | 74 ++++++++++++--------- hscontrol/poll.go | 9 ++- hscontrol/types/common.go | 27 -------- hscontrol/types/node.go | 113 ++++++++++++++++++-------------- 13 files changed, 147 insertions(+), 154 deletions(-) diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index d63611b..be13f66 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" + "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -593,7 +594,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:exit"}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, }, diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index d73c3af..545bd2f 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -274,7 +274,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { } advertisedRoutes := map[netip.Prefix]bool{} - for _, prefix := range node.HostInfo.RoutableIPs { + for _, prefix := range node.Hostinfo.RoutableIPs { advertisedRoutes[prefix] = false } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 02959e6..92730af 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } db.db.Save(&node) @@ -81,7 +81,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } db.db.Save(&node) @@ -152,7 +152,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo1), + Hostinfo: &hostInfo1, } db.db.Save(&node1) @@ -174,7 +174,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo2), + Hostinfo: &hostInfo2, } db.db.Save(&node2) @@ -232,7 +232,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo1), + Hostinfo: &hostInfo1, LastSeen: &now, } db.db.Save(&node1) @@ -266,7 +266,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo2), + Hostinfo: &hostInfo2, LastSeen: &now, } db.db.Save(&node2) @@ -313,9 +313,9 @@ func (s *Suite) TestSubnetFailover(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - node2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ + node2.Hostinfo = &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix, prefix2}, - }) + } err = db.db.Save(&node2).Error c.Assert(err, check.IsNil) @@ -368,7 +368,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo1), + Hostinfo: &hostInfo1, LastSeen: &now, } db.db.Save(&node1) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 5c05146..9139513 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -550,7 +550,7 @@ func (api headscaleV1APIServer) DebugCreateNode( Expiry: &time.Time{}, LastSeen: &time.Time{}, - HostInfo: types.HostInfo(hostinfo), + Hostinfo: &hostinfo, } log.Debug(). diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 806e901..a9028a9 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -195,7 +195,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ "device_name": []string{node.Hostname}, - "device_model": []string{node.HostInfo.OS}, + "device_model": []string{node.Hostinfo.OS}, } if len(node.IPAddresses) > 0 { diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index fd314ae..094a6c7 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -186,8 +186,7 @@ func Test_fullMapResponse(t *testing.T) { AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, Expiry: &expire, - HostInfo: types.HostInfo{}, - Endpoints: []string{}, + Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ { Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), @@ -267,8 +266,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, - HostInfo: types.HostInfo{}, - Endpoints: []string{}, + Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, CreatedAt: created, } @@ -324,8 +322,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, - HostInfo: types.HostInfo{}, - Endpoints: []string{}, + Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, CreatedAt: created, } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index c5cee63..a436772 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -72,8 +72,8 @@ func tailNode( } var derp string - if node.HostInfo.NetInfo != nil { - derp = fmt.Sprintf("127.3.3.40:%d", node.HostInfo.NetInfo.PreferredDERP) + if node.Hostinfo.NetInfo != nil { + derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) } else { derp = "127.3.3.40:0" // Zero means disconnected or unknown. } @@ -90,18 +90,11 @@ func tailNode( return nil, err } - hostInfo := node.GetHostInfo() - online := node.IsOnline() tags, _ := pol.TagsOfNode(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) - endpoints, err := node.EndpointsToAddrPort() - if err != nil { - return nil, err - } - tNode := tailcfg.Node{ ID: tailcfg.NodeID(node.ID), // this is the actual ID StableID: tailcfg.StableNodeID( @@ -118,9 +111,9 @@ func tailNode( DiscoKey: node.DiscoKey, Addresses: addrs, AllowedIPs: allowedIPs, - Endpoints: endpoints, + Endpoints: node.Endpoints, DERP: derp, - Hostinfo: hostInfo.View(), + Hostinfo: node.Hostinfo.View(), Created: node.CreatedAt, Tags: tags, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 734c27f..936f275 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -53,8 +53,10 @@ func TestTailNode(t *testing.T) { wantErr bool }{ { - name: "empty-node", - node: &types.Node{}, + name: "empty-node", + node: &types.Node{ + Hostinfo: &tailcfg.Hostinfo{}, + }, pol: &policy.ACLPolicy{}, dnsConfig: &tailcfg.DNSConfig{}, baseDomain: "", @@ -102,8 +104,7 @@ func TestTailNode(t *testing.T) { AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, Expiry: &expire, - HostInfo: types.HostInfo{}, - Endpoints: []string{}, + Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ { Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 08ce800..11f280a 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -596,10 +596,13 @@ func excludeCorrectlyTaggedNodes( } // for each node if tag is in tags list, don't append it. for _, node := range nodes { - hi := node.GetHostInfo() - found := false - for _, t := range hi.RequestTags { + + if node.Hostinfo == nil { + continue + } + + for _, t := range node.Hostinfo.RequestTags { if util.StringOrPrefixListContains(tags, t) { found = true @@ -787,8 +790,11 @@ func (pol *ACLPolicy) expandIPsFromTag( for _, user := range owners { nodes := filterNodesByUser(nodes, user) for _, node := range nodes { - hi := node.GetHostInfo() - if util.StringOrPrefixListContains(hi.RequestTags, alias) { + if node.Hostinfo == nil { + continue + } + + if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) { node.IPAddresses.AppendToIPSet(&build) } } @@ -882,7 +888,7 @@ func (pol *ACLPolicy) TagsOfNode( validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) - for _, tag := range node.HostInfo.RequestTags { + for _, tag := range node.Hostinfo.RequestTags { owners, err := expandOwnersFromTag(pol, tag) if errors.Is(err, ErrInvalidTag) { invalidTagMap[tag] = true diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 661c9cf..aca1b49 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -418,6 +418,7 @@ acls: User: types.User{ Name: "testuser", }, + Hostinfo: &tailcfg.Hostinfo{}, }, }) @@ -1264,7 +1265,7 @@ func Test_expandAlias(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, @@ -1275,7 +1276,7 @@ func Test_expandAlias(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, @@ -1405,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, @@ -1443,7 +1444,7 @@ func Test_expandAlias(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1454,7 +1455,7 @@ func Test_expandAlias(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1464,13 +1465,15 @@ func Test_expandAlias(t *testing.T) { IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: types.User{Name: "marc"}, + User: types.User{Name: "marc"}, + Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, }, @@ -1520,7 +1523,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1531,7 +1534,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1541,7 +1544,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, user: "joe", @@ -1550,6 +1554,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { &types.Node{ IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, }, @@ -1570,7 +1575,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1581,7 +1586,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1591,7 +1596,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, user: "joe", @@ -1600,6 +1606,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { &types.Node{ IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, }, @@ -1615,7 +1622,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, @@ -1627,12 +1634,14 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { }, User: types.User{Name: "joe"}, ForcedTags: []string{"tag:accountant-webserver"}, + Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, user: "joe", @@ -1641,6 +1650,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { &types.Node{ IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, }, @@ -1656,7 +1666,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, @@ -1667,7 +1677,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, @@ -1677,7 +1687,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, user: "joe", @@ -1688,7 +1699,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, @@ -1699,7 +1710,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, @@ -1709,7 +1720,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: types.User{Name: "joe"}, + User: types.User{Name: "joe"}, + Hostinfo: &tailcfg.Hostinfo{}, }, }, }, @@ -1952,7 +1964,7 @@ func Test_getTags(t *testing.T) { User: types.User{ Name: "joe", }, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:valid"}, }, }, @@ -1972,7 +1984,7 @@ func Test_getTags(t *testing.T) { User: types.User{ Name: "joe", }, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:valid", "tag:invalid"}, }, }, @@ -1992,7 +2004,7 @@ func Test_getTags(t *testing.T) { User: types.User{ Name: "joe", }, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{ "tag:invalid", "tag:valid", @@ -2016,7 +2028,7 @@ func Test_getTags(t *testing.T) { User: types.User{ Name: "joe", }, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:invalid", "very-invalid"}, }, }, @@ -2032,7 +2044,7 @@ func Test_getTags(t *testing.T) { User: types.User{ Name: "joe", }, - HostInfo: types.HostInfo{ + Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:invalid", "very-invalid"}, }, }, @@ -3010,7 +3022,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } pol := &ACLPolicy{ @@ -3062,7 +3074,7 @@ func TestInvalidTagValidUser(t *testing.T) { Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } pol := &ACLPolicy{ @@ -3113,7 +3125,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } pol := &ACLPolicy{ @@ -3174,7 +3186,7 @@ func TestValidTagInvalidUser(t *testing.T) { Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, - HostInfo: types.HostInfo(hostInfo), + Hostinfo: &hostInfo, } hostInfo2 := tailcfg.Hostinfo{ @@ -3191,7 +3203,7 @@ func TestValidTagInvalidUser(t *testing.T) { Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, - HostInfo: types.HostInfo(hostInfo2), + Hostinfo: &hostInfo2, } pol := &ACLPolicy{ diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 050e857..7935381 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -83,15 +83,14 @@ func (h *Headscale) handlePoll( Bool("stream", mapRequest.Stream). Str("node_key", node.NodeKey.ShortString()). Str("node", node.Hostname). - Strs("endpoints", node.Endpoints). Msg("Received endpoint update") now := time.Now().UTC() node.LastSeen = &now node.Hostname = mapRequest.Hostinfo.Hostname - node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) + node.Hostinfo = mapRequest.Hostinfo node.DiscoKey = mapRequest.DiscoKey - node.SetEndpointsFromAddrPorts(mapRequest.Endpoints) + node.Endpoints = mapRequest.Endpoints if err := h.db.NodeSave(node); err != nil { logErr(err, "Failed to persist/update node in the database") @@ -142,9 +141,9 @@ func (h *Headscale) handlePoll( now := time.Now().UTC() node.LastSeen = &now node.Hostname = mapRequest.Hostinfo.Hostname - node.HostInfo = types.HostInfo(*mapRequest.Hostinfo) + node.Hostinfo = mapRequest.Hostinfo node.DiscoKey = mapRequest.DiscoKey - node.SetEndpointsFromAddrPorts(mapRequest.Endpoints) + node.Endpoints = mapRequest.Endpoints // When a node connects to control, list the peers it has at // that given point, further updates are kept in memory in diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index b275fa4..39060ac 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -12,33 +12,6 @@ import ( var ErrCannotParsePrefix = errors.New("cannot parse prefix") -// This is a "wrapper" type around tailscales -// Hostinfo to allow us to add database "serialization" -// methods. This allows us to use a typed values throughout -// the code and not have to marshal/unmarshal and error -// check all over the code. -type HostInfo tailcfg.Hostinfo - -func (hi *HostInfo) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, hi) - - case string: - return json.Unmarshal([]byte(value), hi) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (hi HostInfo) Value() (driver.Value, error) { - bytes, err := json.Marshal(hi) - - return string(bytes), err -} - type IPPrefix netip.Prefix func (i *IPPrefix) Scan(destination interface{}) error { diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index da20bc4..a2fdb91 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -2,6 +2,7 @@ package types import ( "database/sql/driver" + "encoding/json" "errors" "fmt" "net/netip" @@ -27,27 +28,40 @@ var ( type Node struct { ID uint64 `gorm:"primary_key"` - // MachineKeyValue is the string representation of MachineKey + // MachineKeyDatabaseField is the string representation of MachineKey // it is _only_ used for reading and writing the key to the // database and should not be used. // Use MachineKey instead. - MachineKeyValue string `gorm:"column:machine_key;unique_index"` + MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"` + MachineKey key.MachinePublic `gorm:"-"` - // NodeKeyValue is the string representation of NodeKey + // NodeKeyDatabaseField is the string representation of NodeKey // it is _only_ used for reading and writing the key to the // database and should not be used. // Use NodeKey instead. - NodeKeyValue string `gorm:"column:node_key"` + NodeKeyDatabaseField string `gorm:"column:node_key"` + NodeKey key.NodePublic `gorm:"-"` - // DiscoKeyValue is the string representation of DiscoKey + // DiscoKeyDatabaseField is the string representation of DiscoKey // it is _only_ used for reading and writing the key to the // database and should not be used. // Use DiscoKey instead. - DiscoKeyValue string `gorm:"column:disco_key"` + DiscoKeyDatabaseField string `gorm:"column:disco_key"` + DiscoKey key.DiscoPublic `gorm:"-"` - MachineKey key.MachinePublic `gorm:"-"` - NodeKey key.NodePublic `gorm:"-"` - DiscoKey key.DiscoPublic `gorm:"-"` + // EndpointsDatabaseField is the string list representation of Endpoints + // it is _only_ used for reading and writing the key to the + // database and should not be used. + // Use Endpoints instead. + EndpointsDatabaseField StringList `gorm:"column:endpoints"` + Endpoints []netip.AddrPort `gorm:"-"` + + // EndpointsDatabaseField is the string list representation of Endpoints + // it is _only_ used for reading and writing the key to the + // database and should not be used. + // Use Endpoints instead. + HostinfoDatabaseField string `gorm:"column:hostinfo"` + Hostinfo *tailcfg.Hostinfo `gorm:"-"` IPAddresses NodeAddresses @@ -76,9 +90,6 @@ type Node struct { LastSeen *time.Time Expiry *time.Time - HostInfo HostInfo - Endpoints StringList - Routes []Route CreatedAt time.Time @@ -195,31 +206,6 @@ func (node Node) IsExpired() bool { return time.Now().UTC().After(*node.Expiry) } -// TODO(kradalby): Try to replace the types in the DB to be correct. -func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) { - var ret []netip.AddrPort - for _, ep := range node.Endpoints { - addrPort, err := netip.ParseAddrPort(ep) - if err != nil { - return nil, err - } - - ret = append(ret, addrPort) - } - - return ret, nil -} - -// TODO(kradalby): Try to replace the types in the DB to be correct. -func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) { - var strs StringList - for _, addrPort := range in { - strs = append(strs, addrPort.String()) - } - - node.Endpoints = strs -} - // IsOnline returns if the node is connected to Headscale. // This is really a naive implementation, as we don't really see // if there is a working connection between the client and the server. @@ -277,9 +263,22 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { // correctly in the database. // This currently means storing the keys as strings. func (n *Node) BeforeSave(tx *gorm.DB) (err error) { - n.MachineKeyValue = n.MachineKey.String() - n.NodeKeyValue = n.NodeKey.String() - n.DiscoKeyValue = n.DiscoKey.String() + n.MachineKeyDatabaseField = n.MachineKey.String() + n.NodeKeyDatabaseField = n.NodeKey.String() + n.DiscoKeyDatabaseField = n.DiscoKey.String() + + var endpoints StringList + for _, addrPort := range n.Endpoints { + endpoints = append(endpoints, addrPort.String()) + } + + n.EndpointsDatabaseField = endpoints + + hi, err := json.Marshal(n.Hostinfo) + if err != nil { + return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err) + } + n.HostinfoDatabaseField = string(hi) return } @@ -291,23 +290,40 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) { // the proper types. func (n *Node) AfterFind(tx *gorm.DB) (err error) { var machineKey key.MachinePublic - if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil { - return err + if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil { + return fmt.Errorf("failed to unmarshal machine key from db: %w", err) } n.MachineKey = machineKey var nodeKey key.NodePublic - if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil { - return err + if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil { + return fmt.Errorf("failed to unmarshal node key from db: %w", err) } n.NodeKey = nodeKey var discoKey key.DiscoPublic - if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil { - return err + if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil { + return fmt.Errorf("failed to unmarshal disco key from db: %w", err) } n.DiscoKey = discoKey + var endpoints []netip.AddrPort + for _, ep := range n.EndpointsDatabaseField { + addrPort, err := netip.ParseAddrPort(ep) + if err != nil { + return fmt.Errorf("failed to parse endpoint from db: %w", err) + } + + endpoints = append(endpoints, addrPort) + } + n.Endpoints = endpoints + + var hi tailcfg.Hostinfo + if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil { + return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err) + } + n.Hostinfo = &hi + return } @@ -346,11 +362,6 @@ func (node *Node) Proto() *v1.Node { return nodeProto } -// GetHostInfo returns a Hostinfo struct for the node. -func (node *Node) GetHostInfo() tailcfg.Hostinfo { - return tailcfg.Hostinfo(node.HostInfo) -} - func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS