diff --git a/.gitignore b/.gitignore index bcbc9b2..0ba8193 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,5 @@ integration_test/etc/config.dump.yaml # MkDocs .cache /site + +__debug_bin diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index f7c7e3a..37ef423 100644 --- a/cmd/headscale/cli/api_key.go +++ b/cmd/headscale/cli/api_key.go @@ -6,7 +6,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/pterm/pterm" "github.com/rs/zerolog/log" @@ -83,7 +83,7 @@ var listAPIKeys = &cobra.Command{ } tableData = append(tableData, []string{ - strconv.FormatUint(key.GetId(), hscontrol.Base10), + strconv.FormatUint(key.GetId(), util.Base10), key.GetPrefix(), expiration, key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat), diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index f2c8028..7e8e92d 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -4,7 +4,7 @@ import ( "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -93,7 +93,7 @@ var createNodeCmd = &cobra.Command{ return } - if !hscontrol.NodePublicKeyRegex.Match([]byte(machineKey)) { + if !util.NodePublicKeyRegex.Match([]byte(machineKey)) { err = errPreAuthKeyMalformed ErrorOutput( err, diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 772b428..31a0677 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -10,7 +10,7 @@ import ( survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/pterm/pterm" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -529,7 +529,7 @@ func nodesToPtables( var machineKey key.MachinePublic err := machineKey.UnmarshalText( - []byte(hscontrol.MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil { machineKey = key.MachinePublic{} @@ -537,7 +537,7 @@ func nodesToPtables( var nodeKey key.NodePublic err = nodeKey.UnmarshalText( - []byte(hscontrol.NodePublicKeyEnsurePrefix(machine.NodeKey)), + []byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)), ) if err != nil { return nil, err @@ -596,7 +596,7 @@ func nodesToPtables( } nodeData := []string{ - strconv.FormatUint(machine.Id, hscontrol.Base10), + strconv.FormatUint(machine.Id, util.Base10), machine.Name, machine.GetGivenName(), machineKey.ShortString(), diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 3724fe9..3132e99 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -1,11 +1,11 @@ package cli import ( + "errors" "fmt" survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" "github.com/pterm/pterm" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -20,9 +20,7 @@ func init() { userCmd.AddCommand(renameUserCmd) } -const ( - errMissingParameter = hscontrol.Error("missing parameters") -) +var errMissingParameter = errors.New("missing parameters") var userCmd = &cobra.Command{ Use: "users", diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index a2a5d59..2831dbf 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,6 +10,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -39,7 +40,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { // We are doing this here, as in the future could be cool to have it also hot-reload if cfg.ACL.PolicyPath != "" { - aclPath := hscontrol.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) + aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) err = app.LoadACLPolicyFromPath(aclPath) if err != nil { log.Fatal(). @@ -98,7 +99,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. grpcOptions = append( grpcOptions, grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(hscontrol.GrpcSocketDialer), + grpc.WithContextDialer(util.GrpcSocketDialer), ) } else { // If we are not connecting to a local server, require an API key for authentication diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 1b98731..89fd775 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/spf13/viper" "gopkg.in/check.v1" ) @@ -64,7 +65,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") c.Assert( - hscontrol.GetFileMode("unix_socket_permission"), + util.GetFileMode("unix_socket_permission"), check.Equals, fs.FileMode(0o770), ) @@ -107,7 +108,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") c.Assert( - hscontrol.GetFileMode("unix_socket_permission"), + util.GetFileMode("unix_socket_permission"), check.Equals, fs.FileMode(0o770), ) diff --git a/hscontrol/acls.go b/hscontrol/acls.go index 449c7ff..2c81046 100644 --- a/hscontrol/acls.go +++ b/hscontrol/acls.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/tailscale/hujson" "go4.org/netipx" @@ -20,21 +21,16 @@ import ( "tailscale.com/tailcfg" ) -const ( - errEmptyPolicy = Error("empty policy") - errInvalidAction = Error("invalid action") - errInvalidGroup = Error("invalid group") - errInvalidTag = Error("invalid tag") - errInvalidPortFormat = Error("invalid port format") - errWildcardIsNeeded = Error("wildcard as port is required for the protocol") +var ( + errEmptyPolicy = errors.New("empty policy") + errInvalidAction = errors.New("invalid action") + errInvalidGroup = errors.New("invalid group") + errInvalidTag = errors.New("invalid tag") + errInvalidPortFormat = errors.New("invalid port format") + errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") ) const ( - Base8 = 8 - Base10 = 10 - BitSize16 = 16 - BitSize32 = 32 - BitSize64 = 64 portRangeBegin = 0 portRangeEnd = 65535 expectedTokenItems = 2 @@ -123,7 +119,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { } func (h *Headscale) UpdateACLRules() error { - machines, err := h.ListMachines() + machines, err := h.db.ListMachines() if err != nil { return err } @@ -230,7 +226,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { return nil, errEmptyPolicy } - machines, err := h.ListMachines() + machines, err := h.db.ListMachines() if err != nil { return nil, err } @@ -570,7 +566,7 @@ func excludeCorrectlyTaggedNodes( for tag := range aclPolicy.TagOwners { owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) ns := append(owners, user) - if contains(ns, user) { + if util.StringOrPrefixListContains(ns, user) { tags = append(tags, tag) } } @@ -580,7 +576,7 @@ func excludeCorrectlyTaggedNodes( found := false for _, t := range hi.RequestTags { - if contains(tags, t) { + if util.StringOrPrefixListContains(tags, t) { found = true break @@ -614,7 +610,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err rang := strings.Split(portStr, "-") switch len(rang) { case 1: - port, err := strconv.ParseUint(rang[0], Base10, BitSize16) + port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) if err != nil { return nil, err } @@ -624,11 +620,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err }) case expectedTokenItems: - start, err := strconv.ParseUint(rang[0], Base10, BitSize16) + start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) if err != nil { return nil, err } - last, err := strconv.ParseUint(rang[1], Base10, BitSize16) + last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16) if err != nil { return nil, err } @@ -754,7 +750,7 @@ func (pol *ACLPolicy) getIPsFromTag( // check for forced tags for _, machine := range machines { - if contains(machine.ForcedTags, alias) { + if util.StringOrPrefixListContains(machine.ForcedTags, alias) { machine.IPAddresses.AppendToIPSet(&build) } } @@ -783,7 +779,7 @@ func (pol *ACLPolicy) getIPsFromTag( machines := filterMachinesByUser(machines, user) for _, machine := range machines { hi := machine.GetHostInfo() - if contains(hi.RequestTags, alias) { + if util.StringOrPrefixListContains(hi.RequestTags, alias) { machine.IPAddresses.AppendToIPSet(&build) } } diff --git a/hscontrol/acls_test.go b/hscontrol/acls_test.go index 095597f..70a57b8 100644 --- a/hscontrol/acls_test.go +++ b/hscontrol/acls_test.go @@ -30,8 +30,8 @@ func (s *Suite) TestBrokenHuJson(c *check.C) { func (s *Suite) TestInvalidPolicyHuson(c *check.C) { acl := []byte(` { - "valid_json": true, - "but_a_policy_though": false + "valid_json": true, + "but_a_policy_though": false } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -60,129 +60,129 @@ func (s *Suite) TestParseInvalidCIDR(c *check.C) { func (s *Suite) TestRuleInvalidGeneration(c *check.C) { acl := []byte(` { - // Declare static groups of users beyond those in the identity service. - "groups": { - "group:example": [ - "user1@example.com", - "user2@example.com", - ], - }, - // Declare hostname aliases to use in place of IP addresses or subnets. - "hosts": { - "example-host-1": "100.100.100.100", - "example-host-2": "100.100.101.100/24", - }, - // Define who is allowed to use which tags. - "tagOwners": { - // Everyone in the montreal-admins or global-admins group are - // allowed to tag servers as montreal-webserver. - "tag:montreal-webserver": [ - "group:montreal-admins", - "group:global-admins", - ], - // Only a few admins are allowed to create API servers. - "tag:api-server": [ - "group:global-admins", - "example-host-1", - ], - }, - // Access control lists. - "acls": [ - // Engineering users, plus the president, can access port 22 (ssh) - // and port 3389 (remote desktop protocol) on all servers, and all - // ports on git-server or ci-server. - { - "action": "accept", - "src": [ - "group:engineering", - "president@example.com" - ], - "dst": [ - "*:22,3389", - "git-server:*", - "ci-server:*" - ], - }, - // Allow engineer users to access any port on a device tagged with - // tag:production. - { - "action": "accept", - "src": [ - "group:engineers" - ], - "dst": [ - "tag:production:*" - ], - }, - // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts - // on both networks. - { - "action": "accept", - "src": [ - "my-subnet", - "192.168.1.0/24" - ], - "dst": [ - "my-subnet:*", - "192.168.1.0/24:*" - ], - }, - // Allow every user of your network to access anything on the network. - // Comment out this section if you want to define specific ACL - // restrictions above. - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "*:*" - ], - }, - // All users in Montreal are allowed to access the Montreal web - // servers. - { - "action": "accept", - "src": [ - "group:montreal-users" - ], - "dst": [ - "tag:montreal-webserver:80,443" - ], - }, - // Montreal web servers are allowed to make outgoing connections to - // the API servers, but only on https port 443. - // In contrast, this doesn't grant API servers the right to initiate - // any connections. - { - "action": "accept", - "src": [ - "tag:montreal-webserver" - ], - "dst": [ - "tag:api-server:443" - ], - }, - ], - // Declare tests to check functionality of ACL rules - "tests": [ - { - "src": "user1@example.com", - "accept": [ - "example-host-1:22", - "example-host-2:80" - ], - "deny": [ - "exapmle-host-2:100" - ], - }, - { - "src": "user2@example.com", - "accept": [ - "100.60.3.4:22" - ], - }, - ], + // Declare static groups of users beyond those in the identity service. + "groups": { + "group:example": [ + "user1@example.com", + "user2@example.com", + ], + }, + // Declare hostname aliases to use in place of IP addresses or subnets. + "hosts": { + "example-host-1": "100.100.100.100", + "example-host-2": "100.100.101.100/24", + }, + // Define who is allowed to use which tags. + "tagOwners": { + // Everyone in the montreal-admins or global-admins group are + // allowed to tag servers as montreal-webserver. + "tag:montreal-webserver": [ + "group:montreal-admins", + "group:global-admins", + ], + // Only a few admins are allowed to create API servers. + "tag:api-server": [ + "group:global-admins", + "example-host-1", + ], + }, + // Access control lists. + "acls": [ + // Engineering users, plus the president, can access port 22 (ssh) + // and port 3389 (remote desktop protocol) on all servers, and all + // ports on git-server or ci-server. + { + "action": "accept", + "src": [ + "group:engineering", + "president@example.com" + ], + "dst": [ + "*:22,3389", + "git-server:*", + "ci-server:*" + ], + }, + // Allow engineer users to access any port on a device tagged with + // tag:production. + { + "action": "accept", + "src": [ + "group:engineers" + ], + "dst": [ + "tag:production:*" + ], + }, + // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts + // on both networks. + { + "action": "accept", + "src": [ + "my-subnet", + "192.168.1.0/24" + ], + "dst": [ + "my-subnet:*", + "192.168.1.0/24:*" + ], + }, + // Allow every user of your network to access anything on the network. + // Comment out this section if you want to define specific ACL + // restrictions above. + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "*:*" + ], + }, + // All users in Montreal are allowed to access the Montreal web + // servers. + { + "action": "accept", + "src": [ + "group:montreal-users" + ], + "dst": [ + "tag:montreal-webserver:80,443" + ], + }, + // Montreal web servers are allowed to make outgoing connections to + // the API servers, but only on https port 443. + // In contrast, this doesn't grant API servers the right to initiate + // any connections. + { + "action": "accept", + "src": [ + "tag:montreal-webserver" + ], + "dst": [ + "tag:api-server:443" + ], + }, + ], + // Declare tests to check functionality of ACL rules + "tests": [ + { + "src": "user1@example.com", + "accept": [ + "example-host-1:22", + "example-host-2:80" + ], + "deny": [ + "exapmle-host-2:100" + ], + }, + { + "src": "user2@example.com", + "accept": [ + "100.60.3.4:22" + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -192,24 +192,24 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { func (s *Suite) TestBasicRule(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - "192.168.1.0/24" - ], - "dst": [ - "*:22,3389", - "host-1:*", - ], - }, - ], + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -238,13 +238,13 @@ func (s *Suite) TestInvalidAction(c *check.C) { func (s *Suite) TestSshRules(c *check.C) { envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -264,7 +264,7 @@ func (s *Suite) TestSshRules(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{ @@ -348,13 +348,13 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Sources section. func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -374,7 +374,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, @@ -398,13 +398,13 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Destinations section. func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -424,7 +424,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, @@ -448,13 +448,13 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { // tag on a host that isn't owned by a tag owners. So the user // of the host should be valid. func (s *Suite) TestInvalidTagValidUser(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -474,7 +474,7 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ TagOwners: TagOwners{"tag:test": []string{"user1"}}, @@ -497,13 +497,13 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { // an ACL rule is matching the tag to a user. It should not be valid since the // host should be tied to the tag now. func (s *Suite) TestValidTagInvalidUser(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "webserver") + _, err = app.db.GetMachine("user1", "webserver") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -523,8 +523,8 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) - _, err = app.GetMachine("user1", "user") + app.db.db.Save(&machine) + _, err = app.db.GetMachine("user1", "user") hostInfo2 := tailcfg.Hostinfo{ OS: "debian", Hostname: "Hostname", @@ -542,7 +542,7 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo2), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, @@ -571,22 +571,22 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { func (s *Suite) TestPortRange(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - ], - "dst": [ - "host-1:5400-5500", - ], - }, - ], + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -605,43 +605,43 @@ func (s *Suite) TestPortRange(c *check.C) { func (s *Suite) TestProtocolParsing(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "tcp", - "dst": [ - "host-1:*", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "udp", - "dst": [ - "host-1:53", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "icmp", - "dst": [ - "host-1:*", - ], - }, - ], + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -660,22 +660,22 @@ func (s *Suite) TestProtocolParsing(c *check.C) { func (s *Suite) TestPortWildcard(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -694,8 +694,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { } func (s *Suite) TestPortWildcardYAML(c *check.C) { - acl := []byte(` ---- + acl := []byte(`--- hosts: host-1: 100.100.100.100/32 subnet-1: 100.100.101.100/24 @@ -704,8 +703,7 @@ acls: src: - "*" dst: - - host-1:* -`) + - host-1:*`) err := app.LoadACLPolicyFromBytes(acl, "yaml") c.Assert(err, check.IsNil) @@ -722,15 +720,15 @@ acls: } func (s *Suite) TestPortUser(c *check.C) { - user, err := app.CreateUser("testuser") + user, err := app.db.CreateUser("testuser") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("testuser", "testmachine") + _, err = app.db.GetMachine("testuser", "testmachine") c.Assert(err, check.NotNil) - ips, _ := app.getAvailableIPs() + ips, _ := app.db.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "12345", @@ -742,32 +740,32 @@ func (s *Suite) TestPortUser(c *check.C) { IPAddresses: ips, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "acls": [ + { + "action": "accept", + "src": [ + "testuser", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err = app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) rules, err := app.aclPolicy.generateFilterRules(machines, false) @@ -785,15 +783,15 @@ func (s *Suite) TestPortUser(c *check.C) { } func (s *Suite) TestPortGroup(c *check.C) { - user, err := app.CreateUser("testuser") + user, err := app.db.CreateUser("testuser") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("testuser", "testmachine") + _, err = app.db.GetMachine("testuser", "testmachine") c.Assert(err, check.NotNil) - ips, _ := app.getAvailableIPs() + ips, _ := app.db.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "foo", @@ -805,38 +803,38 @@ func (s *Suite) TestPortGroup(c *check.C) { IPAddresses: ips, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) acl := []byte(` { - "groups": { - "group:example": [ - "testuser", - ], - }, + "groups": { + "group:example": [ + "testuser", + ], + }, - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err = app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) rules, err := app.aclPolicy.generateFilterRules(machines, false) diff --git a/hscontrol/addresses.go b/hscontrol/addresses.go new file mode 100644 index 0000000..7f78935 --- /dev/null +++ b/hscontrol/addresses.go @@ -0,0 +1,98 @@ +// Codehere is mostly taken from github.com/tailscale/tailscale +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hscontrol + +import ( + "errors" + "fmt" + "net/netip" + + "github.com/juanfont/headscale/hscontrol/util" + "go4.org/netipx" +) + +var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") + +func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { + var ips MachineAddresses + var err error + for _, ipPrefix := range hsdb.ipPrefixes { + var ip *netip.Addr + ip, err = hsdb.getAvailableIP(ipPrefix) + if err != nil { + return ips, err + } + ips = append(ips, *ip) + } + + return ips, err +} + +func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { + usedIps, err := hsdb.getUsedIPs() + if err != nil { + return nil, err + } + + ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix) + + // Get the first IP in our prefix + ip := ipPrefixNetworkAddress.Next() + + for { + if !ipPrefix.Contains(ip) { + return nil, ErrCouldNotAllocateIP + } + + switch { + case ip.Compare(ipPrefixBroadcastAddress) == 0: + fallthrough + case usedIps.Contains(ip): + fallthrough + case ip == netip.Addr{} || ip.IsLoopback(): + ip = ip.Next() + + continue + + default: + return &ip, nil + } + } +} + +func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { + // FIXME: This really deserves a better data model, + // but this was quick to get running and it should be enough + // to begin experimenting with a dual stack tailnet. + var addressesSlices []string + hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) + + var ips netipx.IPSetBuilder + for _, slice := range addressesSlices { + var machineAddresses MachineAddresses + err := machineAddresses.Scan(slice) + if err != nil { + return &netipx.IPSet{}, fmt.Errorf( + "failed to read ip from database: %w", + err, + ) + } + + for _, ip := range machineAddresses { + ips.Add(ip) + } + } + + ipSet, err := ips.IPSet() + if err != nil { + return &netipx.IPSet{}, fmt.Errorf( + "failed to build IP Set: %w", + err, + ) + } + + return ipSet, nil +} diff --git a/hscontrol/utils_test.go b/hscontrol/addresses_test.go similarity index 75% rename from hscontrol/utils_test.go rename to hscontrol/addresses_test.go index 436df8a..f3be93a 100644 --- a/hscontrol/utils_test.go +++ b/hscontrol/addresses_test.go @@ -8,7 +8,7 @@ import ( ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) @@ -19,16 +19,16 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { } func (s *Suite) TestGetUsedIps(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) - user, err := app.CreateUser("test-ip") + user, err := app.db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -42,9 +42,9 @@ func (s *Suite) TestGetUsedIps(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.Save(&machine) + app.db.db.Save(&machine) - usedIps, err := app.getUsedIPs() + usedIps, err := app.db.getUsedIPs() c.Assert(err, check.IsNil) @@ -56,7 +56,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) c.Assert(usedIps.Contains(expected), check.Equals, true) - machine1, err := app.GetMachineByID(0) + machine1, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) @@ -64,19 +64,19 @@ func (s *Suite) TestGetUsedIps(c *check.C) { } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := app.CreateUser("test-ip-multi") + user, err := app.db.CreateUser("test-ip-multi") c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { - app.ipAllocationMutex.Lock() + app.db.ipAllocationMutex.Lock() - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -90,12 +90,12 @@ func (s *Suite) TestGetMultiIp(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.Save(&machine) + app.db.db.Save(&machine) - app.ipAllocationMutex.Unlock() + app.db.ipAllocationMutex.Unlock() } - usedIps, err := app.getUsedIPs() + usedIps, err := app.db.getUsedIPs() c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -117,7 +117,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs - machine1, err := app.GetMachineByID(1) + machine1, err := app.db.GetMachineByID(1) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert( @@ -126,7 +126,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { netip.MustParseAddr("10.27.0.1"), ) - machine50, err := app.GetMachineByID(50) + machine50, err := app.db.GetMachineByID(50) c.Assert(err, check.IsNil) c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert( @@ -136,7 +136,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { ) expectedNextIP := netip.MustParseAddr("10.27.1.95") - nextIP, err := app.getAvailableIPs() + nextIP, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP), check.Equals, 1) @@ -144,7 +144,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := app.getAvailableIPs() + nextIP2, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP2), check.Equals, 1) @@ -152,7 +152,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -160,13 +160,13 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) - user, err := app.CreateUser("test-ip") + user, err := app.db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -179,23 +179,11 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - ips2, err := app.getAvailableIPs() + ips2, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(ips2), check.Equals, 1) c.Assert(ips2[0].String(), check.Equals, expected.String()) } - -func (s *Suite) TestGenerateRandomStringDNSSafe(c *check.C) { - for i := 0; i < 100000; i++ { - str, err := GenerateRandomStringDNSSafe(8) - if err != nil { - c.Error(err) - } - if len(str) != 8 { - c.Error("invalid length", len(str), str) - } - } -} diff --git a/hscontrol/api.go b/hscontrol/api.go index f8b1496..8e30141 100644 --- a/hscontrol/api.go +++ b/hscontrol/api.go @@ -3,25 +3,28 @@ package hscontrol import ( "bytes" "encoding/json" + "errors" "html/template" "net/http" "time" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/types/key" ) const ( // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. - registrationHoldoff = time.Second * 5 - reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authkey" - RegisterMethodOIDC = "oidc" - RegisterMethodCLI = "cli" - ErrRegisterMethodCLIDoesNotSupportExpire = Error( - "machines registered with CLI does not support expire", - ) + registrationHoldoff = time.Second * 5 + reservedResponseHeaderSize = 4 + RegisterMethodAuthKey = "authkey" + RegisterMethodOIDC = "oidc" + RegisterMethodCLI = "cli" +) + +var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( + "machines registered with CLI does not support expire", ) func (h *Headscale) HealthHandler( @@ -53,7 +56,7 @@ func (h *Headscale) HealthHandler( } } - if err := h.pingDB(req.Context()); err != nil { + if err := h.db.pingDB(req.Context()); err != nil { respond(err) return @@ -95,7 +98,7 @@ func (h *Headscale) RegisterWebAPI( vars := mux.Vars(req) nodeKeyStr, ok := vars["nkey"] - if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -116,7 +119,7 @@ func (h *Headscale) RegisterWebAPI( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), ) if !ok || nodeKeyStr == "" || err != nil { diff --git a/hscontrol/api_common.go b/hscontrol/api_common.go index 3dd65ac..f1b3fd8 100644 --- a/hscontrol/api_common.go +++ b/hscontrol/api_common.go @@ -3,6 +3,7 @@ package hscontrol import ( "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -15,7 +16,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -26,7 +27,7 @@ func (h *Headscale) generateMapResponse( return nil, err } - peers, err := h.getValidPeers(machine) + peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) if err != nil { log.Error(). Caller(). @@ -37,9 +38,9 @@ func (h *Headscale) generateMapResponse( return nil, err } - profiles := h.getMapResponseUserProfiles(*machine, peers) + profiles := h.db.getMapResponseUserProfiles(*machine, peers) - nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -107,7 +108,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). // Interface("payload", resp). - Msgf("Generated map response: %s", tailMapResponseToString(resp)) + Msgf("Generated map response: %s", util.TailMapResponseToString(resp)) return &resp, nil } diff --git a/hscontrol/api_key.go b/hscontrol/api_key.go index 6382a33..bf2ccf3 100644 --- a/hscontrol/api_key.go +++ b/hscontrol/api_key.go @@ -1,11 +1,13 @@ package hscontrol import ( + "errors" "fmt" "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -13,10 +15,10 @@ import ( const ( apiPrefixLength = 7 apiKeyLength = 32 - - ErrAPIKeyFailedToParse = Error("Failed to parse ApiKey") ) +var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") + // APIKey describes the datamodel for API keys used to remotely authenticate with // headscale. type APIKey struct { @@ -30,15 +32,15 @@ type APIKey struct { } // CreateAPIKey creates a new ApiKey in a user, and returns it. -func (h *Headscale) CreateAPIKey( +func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *APIKey, error) { - prefix, err := GenerateRandomStringURLSafe(apiPrefixLength) + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err } - toBeHashed, err := GenerateRandomStringURLSafe(apiKeyLength) + toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) if err != nil { return "", nil, err } @@ -57,7 +59,7 @@ func (h *Headscale) CreateAPIKey( Expiration: expiration, } - if err := h.db.Save(&key).Error; err != nil { + if err := hsdb.db.Save(&key).Error; err != nil { return "", nil, fmt.Errorf("failed to save API key to database: %w", err) } @@ -65,9 +67,9 @@ func (h *Headscale) CreateAPIKey( } // ListAPIKeys returns the list of ApiKeys for a user. -func (h *Headscale) ListAPIKeys() ([]APIKey, error) { +func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { keys := []APIKey{} - if err := h.db.Find(&keys).Error; err != nil { + if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err } @@ -75,9 +77,9 @@ func (h *Headscale) ListAPIKeys() ([]APIKey, error) { } // GetAPIKey returns a ApiKey for a given key. -func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) { +func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { key := APIKey{} - if result := h.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -85,9 +87,9 @@ func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) { } // GetAPIKeyByID returns a ApiKey for a given id. -func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) { +func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { key := APIKey{} - if result := h.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -96,8 +98,8 @@ func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. -func (h *Headscale) DestroyAPIKey(key APIKey) error { - if result := h.db.Unscoped().Delete(key); result.Error != nil { +func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { + if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -105,21 +107,21 @@ func (h *Headscale) DestroyAPIKey(key APIKey) error { } // ExpireAPIKey marks a ApiKey as expired. -func (h *Headscale) ExpireAPIKey(key *APIKey) error { - if err := h.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { +func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { + if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } return nil } -func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) { +func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse } - key, err := h.GetAPIKey(prefix) + key, err := hsdb.GetAPIKey(prefix) if err != nil { return false, fmt.Errorf("failed to validate api key: %w", err) } diff --git a/hscontrol/api_key_test.go b/hscontrol/api_key_test.go index fd4fa00..007b5d1 100644 --- a/hscontrol/api_key_test.go +++ b/hscontrol/api_key_test.go @@ -7,7 +7,7 @@ import ( ) func (*Suite) TestCreateAPIKey(c *check.C) { - apiKeyStr, apiKey, err := app.CreateAPIKey(nil) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) @@ -16,74 +16,74 @@ func (*Suite) TestCreateAPIKey(c *check.C) { c.Assert(apiKey.Hash, check.NotNil) c.Assert(apiKeyStr, check.Not(check.Equals), "") - _, err = app.ListAPIKeys() + _, err = app.db.ListAPIKeys() c.Assert(err, check.IsNil) - keys, err := app.ListAPIKeys() + keys, err := app.db.ListAPIKeys() c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) } func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { - key, err := app.GetAPIKey("does-not-exist") + key, err := app.db.GetAPIKey("does-not-exist") c.Assert(err, check.NotNil) c.Assert(key, check.IsNil) } func (*Suite) TestValidateAPIKeyOk(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) } func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowMinus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, false) now := time.Now() - apiKeyStrNow, apiKey, err := app.CreateAPIKey(&now) + apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - validNow, err := app.ValidateAPIKey(apiKeyStrNow) + validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) c.Assert(err, check.IsNil) c.Assert(validNow, check.Equals, false) - validSilly, err := app.ValidateAPIKey("nota.validkey") + validSilly, err := app.db.ValidateAPIKey("nota.validkey") c.Assert(err, check.NotNil) c.Assert(validSilly, check.Equals, false) - validWithErr, err := app.ValidateAPIKey("produceerrorkey") + validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") c.Assert(err, check.NotNil) c.Assert(validWithErr, check.Equals, false) } func (*Suite) TestExpireAPIKey(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) - err = app.ExpireAPIKey(apiKey) + err = app.db.ExpireAPIKey(apiKey) c.Assert(err, check.IsNil) c.Assert(apiKey.Expiration, check.NotNil) - notValid, err := app.ValidateAPIKey(apiKeyStr) + notValid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) } diff --git a/hscontrol/app.go b/hscontrol/app.go index b8dceba..38d4ec8 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -23,6 +23,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -41,24 +42,21 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" ) -const ( - errSTUNAddressNotSet = Error("STUN address not set") - errUnsupportedDatabase = Error("unsupported DB") - errUnsupportedLetsEncryptChallengeType = Error( +var ( + errSTUNAddressNotSet = errors.New("STUN address not set") + errUnsupportedDatabase = errors.New("unsupported DB") + errUnsupportedLetsEncryptChallengeType = errors.New( "unknown value for Lets Encrypt challenge type", ) ) const ( AuthPrefix = "Bearer " - Postgres = "postgres" - Sqlite = "sqlite3" updateInterval = 5000 HTTPReadTimeout = 30 * time.Second HTTPShutdownTimeout = 3 * time.Second @@ -75,7 +73,7 @@ const ( // Headscale represents the base app of the service. type Headscale struct { cfg *Config - db *gorm.DB + db *HSDatabase dbString string dbType string dbDebug bool @@ -96,10 +94,11 @@ type Headscale struct { registrationCache *cache.Cache - ipAllocationMutex sync.Mutex - shutdownChan chan struct{} pollNetMapStreamWG sync.WaitGroup + + stateUpdateChan chan struct{} + cancelStateUpdateChan chan struct{} } func NewHeadscale(cfg *Config) (*Headscale, error) { @@ -164,13 +163,27 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, lastStateChange: xsync.NewMapOf[time.Time](), + + stateUpdateChan: make(chan struct{}), + cancelStateUpdateChan: make(chan struct{}), } - err = app.initDB() + go app.watchStateChannel() + + db, err := NewHeadscaleDatabase( + cfg.DBtype, + dbString, + cfg.OIDC.StripEmaildomain, + app.dbDebug, + app.stateUpdateChan, + cfg.IPPrefixes, + cfg.BaseDomain) if err != nil { return nil, err } + app.db = db + if cfg.OIDC.Issuer != "" { err = app.initOIDC() if err != nil { @@ -231,7 +244,7 @@ func (h *Headscale) expireExpiredMachines(milliSeconds int64) { func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - err := h.handlePrimarySubnetFailover() + err := h.db.handlePrimarySubnetFailover() if err != nil { log.Error().Err(err).Msg("failed to handle primary subnet failover") } @@ -239,7 +252,7 @@ func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { } func (h *Headscale) expireEphemeralNodesWorker() { - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -247,7 +260,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { } for _, user := range users { - machines, err := h.ListMachinesByUser(user.Name) + machines, err := h.db.ListMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -267,7 +280,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") - err = h.db.Unscoped().Delete(machine).Error + err = h.db.db.Unscoped().Delete(machine).Error if err != nil { log.Error(). Err(err). @@ -284,7 +297,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { } func (h *Headscale) expireExpiredMachinesWorker() { - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -292,7 +305,7 @@ func (h *Headscale) expireExpiredMachinesWorker() { } for _, user := range users { - machines, err := h.ListMachinesByUser(user.Name) + machines, err := h.db.ListMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -308,7 +321,7 @@ func (h *Headscale) expireExpiredMachinesWorker() { machine.Expiry.After(h.getLastStateChange(user)) { expiredFound = true - err := h.ExpireMachine(&machines[index]) + err := h.db.ExpireMachine(&machines[index]) if err != nil { log.Error(). Err(err). @@ -387,7 +400,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, ) } - valid, err := h.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) + valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) if err != nil { log.Error(). Caller(). @@ -438,7 +451,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler return } - valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) + valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if err != nil { log.Error(). Caller(). @@ -597,7 +610,7 @@ func (h *Headscale) Serve() error { h.cfg.UnixSocket, []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(GrpcSocketDialer), + grpc.WithContextDialer(util.GrpcSocketDialer), }..., ) if err != nil { @@ -760,7 +773,7 @@ func (h *Headscale) Serve() error { // TODO(kradalby): Reload config on SIGHUP if h.cfg.ACL.PolicyPath != "" { - aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) + aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) err := h.LoadACLPolicyFromPath(aclPath) if err != nil { log.Error().Err(err).Msg("Failed to reload ACL policy") @@ -778,6 +791,7 @@ func (h *Headscale) Serve() error { Msg("Received signal to stop, shutting down gracefully") close(h.shutdownChan) + h.pollNetMapStreamWG.Wait() // Gracefully shut down servers @@ -806,8 +820,12 @@ func (h *Headscale) Serve() error { // Stop listening (and unlink the socket if unix type): socketListener.Close() + <-h.cancelStateUpdateChan + close(h.stateUpdateChan) + close(h.cancelStateUpdateChan) + // Close db connections - db, err := h.db.DB() + db, err := h.db.db.DB() if err != nil { log.Error().Err(err).Msg("Failed to get db handle") } @@ -905,12 +923,25 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } +// TODO(kradalby): baby steps, make this more robust. +func (h *Headscale) watchStateChannel() { + for { + select { + case <-h.stateUpdateChan: + h.setLastStateChangeToNow() + + case <-h.cancelStateUpdateChan: + return + } + } +} + func (h *Headscale) setLastStateChangeToNow() { var err error now := time.Now().UTC() - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error(). Caller(). @@ -1002,7 +1033,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { } trimmedPrivateKey := strings.TrimSpace(string(privateKey)) - privateKeyEnsurePrefix := PrivateKeyEnsurePrefix(trimmedPrivateKey) + privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey) var machineKey key.MachinePrivate if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil { diff --git a/hscontrol/app_test.go b/hscontrol/app_test.go index 7d3907d..1b4e91e 100644 --- a/hscontrol/app_test.go +++ b/hscontrol/app_test.go @@ -42,18 +42,32 @@ func (s *Suite) ResetDB(c *check.C) { IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, + OIDC: OIDCConfig{ + StripEmaildomain: false, + }, } + // TODO(kradalby): make this use NewHeadscale properly so it doesnt drift app = Headscale{ cfg: &cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", + + stateUpdateChan: make(chan struct{}), + cancelStateUpdateChan: make(chan struct{}), } - err = app.initDB() - if err != nil { - c.Fatal(err) - } - db, err := app.openDB() + + go app.watchStateChannel() + + db, err := NewHeadscaleDatabase( + app.dbType, + app.dbString, + cfg.OIDC.StripEmaildomain, + false, + app.stateUpdateChan, + cfg.IPPrefixes, + "", + ) if err != nil { c.Fatal(err) } diff --git a/hscontrol/config.go b/hscontrol/config.go index 0e83a1c..63deace 100644 --- a/hscontrol/config.go +++ b/hscontrol/config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -271,15 +272,15 @@ func GetTLSConfig() TLSConfig { LetsEncrypt: LetsEncryptConfig{ Hostname: viper.GetString("tls_letsencrypt_hostname"), Listen: viper.GetString("tls_letsencrypt_listen"), - CacheDir: AbsolutePathFromConfigPath( + CacheDir: util.AbsolutePathFromConfigPath( viper.GetString("tls_letsencrypt_cache_dir"), ), ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), }, - CertPath: AbsolutePathFromConfigPath( + CertPath: util.AbsolutePathFromConfigPath( viper.GetString("tls_cert_path"), ), - KeyPath: AbsolutePathFromConfigPath( + KeyPath: util.AbsolutePathFromConfigPath( viper.GetString("tls_key_path"), ), } @@ -585,10 +586,10 @@ func GetHeadscaleConfig() (*Config, error) { DisableUpdateCheck: viper.GetBool("disable_check_updates"), IPPrefixes: prefixes, - PrivateKeyPath: AbsolutePathFromConfigPath( + PrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("private_key_path"), ), - NoisePrivateKeyPath: AbsolutePathFromConfigPath( + NoisePrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("noise.private_key_path"), ), BaseDomain: baseDomain, @@ -604,7 +605,7 @@ func GetHeadscaleConfig() (*Config, error) { ), DBtype: viper.GetString("db_type"), - DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")), + DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), DBhost: viper.GetString("db_host"), DBport: viper.GetInt("db_port"), DBname: viper.GetString("db_name"), @@ -620,7 +621,7 @@ func GetHeadscaleConfig() (*Config, error) { ACMEURL: viper.GetString("acme_url"), UnixSocket: viper.GetString("unix_socket"), - UnixSocketPermission: GetFileMode("unix_socket_permission"), + UnixSocketPermission: util.GetFileMode("unix_socket_permission"), OIDC: OIDCConfig{ OnlyStartIfOIDCIsAvailable: viper.GetBool( diff --git a/hscontrol/db.go b/hscontrol/db.go index 14df4b3..e80a3c3 100644 --- a/hscontrol/db.go +++ b/hscontrol/db.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/netip" + "sync" "time" "github.com/glebarez/sqlite" @@ -19,55 +20,90 @@ import ( const ( dbVersion = "1" + Postgres = "postgres" + Sqlite = "sqlite3" +) - errValueNotFound = Error("not found") - ErrCannotParsePrefix = Error("cannot parse prefix") +var ( + errValueNotFound = errors.New("not found") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + errDatabaseNotSupported = errors.New("database type not supported") ) // KV is a key-value store in a psql table. For future use... +// TODO(kradalby): Is this used for anything? type KV struct { Key string Value string } -func (h *Headscale) initDB() error { - db, err := h.openDB() +type HSDatabase struct { + db *gorm.DB + notifyStateChan chan<- struct{} + + ipAllocationMutex sync.Mutex + + ipPrefixes []netip.Prefix + baseDomain string + stripEmailDomain bool +} + +// TODO(kradalby): assemble this struct from toptions or something typed +// rather than arguments. +func NewHeadscaleDatabase( + dbType, connectionAddr string, + stripEmailDomain, debug bool, + notifyStateChan chan<- struct{}, + ipPrefixes []netip.Prefix, + baseDomain string, +) (*HSDatabase, error) { + dbConn, err := openDB(dbType, connectionAddr, debug) if err != nil { - return err - } - h.db = db - - if h.dbType == Postgres { - db.Exec(`create extension if not exists "uuid-ossp";`) + return nil, err } - _ = db.Migrator().RenameTable("namespaces", "users") + db := HSDatabase{ + db: dbConn, + notifyStateChan: notifyStateChan, - err = db.AutoMigrate(&User{}) + ipPrefixes: ipPrefixes, + baseDomain: baseDomain, + stripEmailDomain: stripEmailDomain, + } + + log.Debug().Msgf("database %#v", dbConn) + + if dbType == Postgres { + dbConn.Exec(`create extension if not exists "uuid-ossp";`) + } + + _ = dbConn.Migrator().RenameTable("namespaces", "users") + + err = dbConn.AutoMigrate(User{}) if err != nil { - return err + return nil, err } - _ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") - _ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") - _ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") - _ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") // GivenName is used as the primary source of DNS names, make sure // the field is populated and normalized if it was not when the // machine was registered. - _ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") // If the Machine table has a column for registered, // find all occourences of "false" and drop them. Then // remove the column. - if db.Migrator().HasColumn(&Machine{}, "registered") { + if dbConn.Migrator().HasColumn(&Machine{}, "registered") { log.Info(). Msg(`Database has legacy "registered" column in machine, removing...`) machines := Machines{} - if err := h.db.Not("registered").Find(&machines).Error; err != nil { + if err := dbConn.Not("registered").Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -76,7 +112,7 @@ func (h *Headscale) initDB() error { Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). Msg("Deleting unregistered machine") - if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil { + if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { log.Error(). Err(err). Str("machine", machine.Hostname). @@ -85,18 +121,18 @@ func (h *Headscale) initDB() error { } } - err := db.Migrator().DropColumn(&Machine{}, "registered") + err := dbConn.Migrator().DropColumn(&Machine{}, "registered") if err != nil { log.Error().Err(err).Msg("Error dropping registered column") } } - err = db.AutoMigrate(&Route{}) + err = dbConn.AutoMigrate(&Route{}) if err != nil { - return err + return nil, err } - if db.Migrator().HasColumn(&Machine{}, "enabled_routes") { + if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") type MachineAux struct { @@ -105,7 +141,7 @@ func (h *Headscale) initDB() error { } machinesAux := []MachineAux{} - err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error + err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error if err != nil { log.Fatal().Err(err).Msg("Error accessing db") } @@ -120,7 +156,7 @@ func (h *Headscale) initDB() error { continue } - err = db.Preload("Machine"). + err = dbConn.Preload("Machine"). Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). First(&Route{}). Error @@ -138,7 +174,7 @@ func (h *Headscale) initDB() error { Enabled: true, Prefix: IPPrefix(prefix), } - if err := h.db.Create(&route).Error; err != nil { + if err := dbConn.Create(&route).Error; err != nil { log.Error().Err(err).Msg("Error creating route") } else { log.Info(). @@ -149,20 +185,20 @@ func (h *Headscale) initDB() error { } } - err = db.Migrator().DropColumn(&Machine{}, "enabled_routes") + err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") if err != nil { log.Error().Err(err).Msg("Error dropping enabled_routes column") } } - err = db.AutoMigrate(&Machine{}) + err = dbConn.AutoMigrate(&Machine{}) if err != nil { - return err + return nil, err } - if db.Migrator().HasColumn(&Machine{}, "given_name") { + if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { machines := Machines{} - if err := h.db.Find(&machines).Error; err != nil { + if err := dbConn.Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -170,7 +206,7 @@ func (h *Headscale) initDB() error { if machine.GivenName == "" { normalizedHostname, err := NormalizeToFQDNRules( machine.Hostname, - h.cfg.OIDC.StripEmaildomain, + stripEmailDomain, ) if err != nil { log.Error(). @@ -180,7 +216,7 @@ func (h *Headscale) initDB() error { Msg("Failed to normalize machine hostname in DB migration") } - err = h.RenameMachine(&machines[item], normalizedHostname) + err = db.RenameMachine(&machines[item], normalizedHostname) if err != nil { log.Error(). Caller(). @@ -192,51 +228,51 @@ func (h *Headscale) initDB() error { } } - err = db.AutoMigrate(&KV{}) + err = dbConn.AutoMigrate(&KV{}) if err != nil { - return err + return nil, err } - err = db.AutoMigrate(&PreAuthKey{}) + err = dbConn.AutoMigrate(&PreAuthKey{}) if err != nil { - return err + return nil, err } - err = db.AutoMigrate(&PreAuthKeyACLTag{}) + err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) if err != nil { - return err + return nil, err } - _ = db.Migrator().DropTable("shared_machines") + _ = dbConn.Migrator().DropTable("shared_machines") - err = db.AutoMigrate(&APIKey{}) + err = dbConn.AutoMigrate(&APIKey{}) if err != nil { - return err + return nil, err } - err = h.setValue("db_version", dbVersion) + // TODO(kradalby): is this needed? + err = db.setValue("db_version", dbVersion) - return err + return &db, err } -func (h *Headscale) openDB() (*gorm.DB, error) { - var db *gorm.DB - var err error +func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { + log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") - var log logger.Interface - if h.dbDebug { - log = logger.Default + var dbLogger logger.Interface + if debug { + dbLogger = logger.Default } else { - log = logger.Default.LogMode(logger.Silent) + dbLogger = logger.Default.LogMode(logger.Silent) } - switch h.dbType { + switch dbType { case Sqlite: - db, err = gorm.Open( - sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"), + db, err := gorm.Open( + sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, - Logger: log, + Logger: dbLogger, }, ) @@ -250,24 +286,30 @@ func (h *Headscale) openDB() (*gorm.DB, error) { sqlDB.SetMaxOpenConns(1) sqlDB.SetConnMaxIdleTime(time.Hour) + return db, err + case Postgres: - db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{ + return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, - Logger: log, + Logger: dbLogger, }) } - if err != nil { - return nil, err - } + return nil, fmt.Errorf( + "database of type %s is not supported: %w", + dbType, + errDatabaseNotSupported, + ) +} - return db, nil +func (hsdb *HSDatabase) notifyStateChange() { + hsdb.notifyStateChan <- struct{}{} } // getValue returns the value for the given key in KV. -func (h *Headscale) getValue(key string) (string, error) { +func (hsdb *HSDatabase) getValue(key string) (string, error) { var row KV - if result := h.db.First(&row, "key = ?", key); errors.Is( + if result := hsdb.db.First(&row, "key = ?", key); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -278,34 +320,34 @@ func (h *Headscale) getValue(key string) (string, error) { } // setValue sets value for the given key in KV. -func (h *Headscale) setValue(key string, value string) error { +func (hsdb *HSDatabase) setValue(key string, value string) error { keyValue := KV{ Key: key, Value: value, } - if _, err := h.getValue(key); err == nil { - h.db.Model(&keyValue).Where("key = ?", key).Update("value", value) + if _, err := hsdb.getValue(key); err == nil { + hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value) return nil } - if err := h.db.Create(keyValue).Error; err != nil { + if err := hsdb.db.Create(keyValue).Error; err != nil { return fmt.Errorf("failed to create key value pair in the database: %w", err) } return nil } -func (h *Headscale) pingDB(ctx context.Context) error { +func (hsdb *HSDatabase) pingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - db, err := h.db.DB() + sqlDB, err := hsdb.db.DB() if err != nil { return err } - return db.PingContext(ctx) + return sqlDB.PingContext(ctx) } // This is a "wrapper" type around tailscales diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go index b825721..671a712 100644 --- a/hscontrol/dns_test.go +++ b/hscontrol/dns_test.go @@ -112,16 +112,16 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) { } func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyInShared1, err := app.CreatePreAuthKey( + preAuthKeyInShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -130,7 +130,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared2, err := app.CreatePreAuthKey( + preAuthKeyInShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -139,7 +139,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared3, err := app.CreatePreAuthKey( + preAuthKeyInShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -148,7 +148,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - PreAuthKey2InShared1, err := app.CreatePreAuthKey( + PreAuthKey2InShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -157,7 +157,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -172,9 +172,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -189,9 +189,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -206,9 +206,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -223,7 +223,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -232,7 +232,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Proxied: true, } - peersOfMachineInShared1, err := app.getPeers(machineInShared1) + peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( @@ -259,16 +259,16 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { } func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyInShared1, err := app.CreatePreAuthKey( + preAuthKeyInShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -277,7 +277,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared2, err := app.CreatePreAuthKey( + preAuthKeyInShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -286,7 +286,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared3, err := app.CreatePreAuthKey( + preAuthKeyInShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -295,7 +295,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKey2InShared1, err := app.CreatePreAuthKey( + preAuthKey2InShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -304,7 +304,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -319,9 +319,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -336,9 +336,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -353,9 +353,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -370,7 +370,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -379,7 +379,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Proxied: false, } - peersOfMachine1Shared1, err := app.getPeers(machineInShared1) + peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a65a380..4a26d08 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,6 +8,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -30,7 +31,7 @@ func (api headscaleV1APIServer) GetUser( ctx context.Context, request *v1.GetUserRequest, ) (*v1.GetUserResponse, error) { - user, err := api.h.GetUser(request.GetName()) + user, err := api.h.db.GetUser(request.GetName()) if err != nil { return nil, err } @@ -42,7 +43,7 @@ func (api headscaleV1APIServer) CreateUser( ctx context.Context, request *v1.CreateUserRequest, ) (*v1.CreateUserResponse, error) { - user, err := api.h.CreateUser(request.GetName()) + user, err := api.h.db.CreateUser(request.GetName()) if err != nil { return nil, err } @@ -54,12 +55,12 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.RenameUser(request.GetOldName(), request.GetNewName()) + err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) if err != nil { return nil, err } - user, err := api.h.GetUser(request.GetNewName()) + user, err := api.h.db.GetUser(request.GetNewName()) if err != nil { return nil, err } @@ -71,7 +72,7 @@ func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.DestroyUser(request.GetName()) + err := api.h.db.DestroyUser(request.GetName()) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (api headscaleV1APIServer) ListUsers( ctx context.Context, request *v1.ListUsersRequest, ) (*v1.ListUsersResponse, error) { - users, err := api.h.ListUsers() + users, err := api.h.db.ListUsers() if err != nil { return nil, err } @@ -116,7 +117,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } - preAuthKey, err := api.h.CreatePreAuthKey( + preAuthKey, err := api.h.db.CreatePreAuthKey( request.GetUser(), request.GetReusable(), request.GetEphemeral(), @@ -134,12 +135,12 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.GetPreAuthKey(request.GetUser(), request.Key) + preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) if err != nil { return nil, err } - err = api.h.ExpirePreAuthKey(preAuthKey) + err = api.h.db.ExpirePreAuthKey(preAuthKey) if err != nil { return nil, err } @@ -151,7 +152,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.ListPreAuthKeys(request.GetUser()) + preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) if err != nil { return nil, err } @@ -173,7 +174,8 @@ func (api headscaleV1APIServer) RegisterMachine( Str("node_key", request.GetKey()). Msg("Registering machine") - machine, err := api.h.RegisterMachineFromAuthCallback( + machine, err := api.h.db.RegisterMachineFromAuthCallback( + api.h.registrationCache, request.GetKey(), request.GetUser(), nil, @@ -190,7 +192,7 @@ func (api headscaleV1APIServer) GetMachine( ctx context.Context, request *v1.GetMachineRequest, ) (*v1.GetMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } @@ -202,7 +204,7 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } @@ -216,7 +218,7 @@ func (api headscaleV1APIServer) SetTags( } } - err = api.h.SetTags(machine, request.GetTags()) + err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) if err != nil { return &v1.SetTagsResponse{ Machine: nil, @@ -248,12 +250,12 @@ func (api headscaleV1APIServer) DeleteMachine( ctx context.Context, request *v1.DeleteMachineRequest, ) (*v1.DeleteMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.DeleteMachine( + err = api.h.db.DeleteMachine( machine, ) if err != nil { @@ -267,12 +269,12 @@ func (api headscaleV1APIServer) ExpireMachine( ctx context.Context, request *v1.ExpireMachineRequest, ) (*v1.ExpireMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - api.h.ExpireMachine( + api.h.db.ExpireMachine( machine, ) @@ -288,12 +290,12 @@ func (api headscaleV1APIServer) RenameMachine( ctx context.Context, request *v1.RenameMachineRequest, ) (*v1.RenameMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.RenameMachine( + err = api.h.db.RenameMachine( machine, request.GetNewName(), ) @@ -314,7 +316,7 @@ func (api headscaleV1APIServer) ListMachines( request *v1.ListMachinesRequest, ) (*v1.ListMachinesResponse, error) { if request.GetUser() != "" { - machines, err := api.h.ListMachinesByUser(request.GetUser()) + machines, err := api.h.db.ListMachinesByUser(request.GetUser()) if err != nil { return nil, err } @@ -327,7 +329,7 @@ func (api headscaleV1APIServer) ListMachines( return &v1.ListMachinesResponse{Machines: response}, nil } - machines, err := api.h.ListMachines() + machines, err := api.h.db.ListMachines() if err != nil { return nil, err } @@ -352,12 +354,12 @@ func (api headscaleV1APIServer) MoveMachine( ctx context.Context, request *v1.MoveMachineRequest, ) (*v1.MoveMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.SetMachineUser(machine, request.GetUser()) + err = api.h.db.SetMachineUser(machine, request.GetUser()) if err != nil { return nil, err } @@ -369,7 +371,7 @@ func (api headscaleV1APIServer) GetRoutes( ctx context.Context, request *v1.GetRoutesRequest, ) (*v1.GetRoutesResponse, error) { - routes, err := api.h.GetRoutes() + routes, err := api.h.db.GetRoutes() if err != nil { return nil, err } @@ -383,7 +385,7 @@ func (api headscaleV1APIServer) EnableRoute( ctx context.Context, request *v1.EnableRouteRequest, ) (*v1.EnableRouteResponse, error) { - err := api.h.EnableRoute(request.GetRouteId()) + err := api.h.db.EnableRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -395,7 +397,7 @@ func (api headscaleV1APIServer) DisableRoute( ctx context.Context, request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { - err := api.h.DisableRoute(request.GetRouteId()) + err := api.h.db.DisableRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -407,12 +409,12 @@ func (api headscaleV1APIServer) GetMachineRoutes( ctx context.Context, request *v1.GetMachineRoutesRequest, ) (*v1.GetMachineRoutesResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - routes, err := api.h.GetMachineRoutes(machine) + routes, err := api.h.db.GetMachineRoutes(machine) if err != nil { return nil, err } @@ -426,7 +428,7 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - err := api.h.DeleteRoute(request.GetRouteId()) + err := api.h.db.DeleteRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -443,7 +445,7 @@ func (api headscaleV1APIServer) CreateApiKey( expiration = request.GetExpiration().AsTime() } - apiKey, _, err := api.h.CreateAPIKey( + apiKey, _, err := api.h.db.CreateAPIKey( &expiration, ) if err != nil { @@ -460,12 +462,12 @@ func (api headscaleV1APIServer) ExpireApiKey( var apiKey *APIKey var err error - apiKey, err = api.h.GetAPIKey(request.Prefix) + apiKey, err = api.h.db.GetAPIKey(request.Prefix) if err != nil { return nil, err } - err = api.h.ExpireAPIKey(apiKey) + err = api.h.db.ExpireAPIKey(apiKey) if err != nil { return nil, err } @@ -477,7 +479,7 @@ func (api headscaleV1APIServer) ListApiKeys( ctx context.Context, request *v1.ListApiKeysRequest, ) (*v1.ListApiKeysResponse, error) { - apiKeys, err := api.h.ListAPIKeys() + apiKeys, err := api.h.db.ListAPIKeys() if err != nil { return nil, err } @@ -495,12 +497,12 @@ func (api headscaleV1APIServer) DebugCreateMachine( ctx context.Context, request *v1.DebugCreateMachineRequest, ) (*v1.DebugCreateMachineResponse, error) { - user, err := api.h.GetUser(request.GetUser()) + user, err := api.h.db.GetUser(request.GetUser()) if err != nil { return nil, err } - routes, err := stringToIPPrefix(request.GetRoutes()) + routes, err := util.StringToIPPrefix(request.GetRoutes()) if err != nil { return nil, err } @@ -517,7 +519,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( Hostname: "DebugTestMachine", } - givenName, err := api.h.GenerateGivenName(request.GetKey(), request.GetName()) + givenName, err := api.h.db.GenerateGivenName(request.GetKey(), request.GetName()) if err != nil { return nil, err } @@ -542,7 +544,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( } api.h.registrationCache.Set( - NodePublicKeyStripPrefix(nodeKey), + util.NodePublicKeyStripPrefix(nodeKey), newMachine, registerCacheExpiration, ) diff --git a/hscontrol/machine.go b/hscontrol/machine.go index 9f04d8c..846112b 100644 --- a/hscontrol/machine.go +++ b/hscontrol/machine.go @@ -11,6 +11,8 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "github.com/samber/lo" "go4.org/netipx" @@ -21,23 +23,23 @@ import ( ) const ( - ErrMachineNotFound = Error("machine not found") - ErrMachineRouteIsNotAvailable = Error("route is not available on machine") - ErrMachineAddressesInvalid = Error("failed to parse machine addresses") - ErrMachineNotFoundRegistrationCache = Error( - "machine not found in registration cache", - ) - ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface") - ErrHostnameTooLong = Error("Hostname too long") - ErrDifferentRegisteredUser = Error( - "machine was previously registered with a different user", - ) MachineGivenNameHashLength = 8 MachineGivenNameTrimSize = 2 + maxHostnameLength = 255 ) -const ( - maxHostnameLength = 255 +var ( + ErrMachineNotFound = errors.New("machine not found") + ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine") + ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") + ErrMachineNotFoundRegistrationCache = errors.New( + "machine not found in registration cache", + ) + ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface") + ErrHostnameTooLong = errors.New("hostname too long") + ErrDifferentRegisteredUser = errors.New( + "machine was previously registered with a different user", + ) ) // Machine is a Headscale client. @@ -188,8 +190,10 @@ func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine // filterMachinesByACL wrapper function to not have devs pass around locks and maps // related to the application outside of tests. -func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines { - return filterMachinesByACL(currentMachine, peers, h.aclRules) +func (hsdb *HSDatabase) filterMachinesByACL( + aclRules []tailcfg.FilterRule, + currentMachine *Machine, peers Machines) Machines { + return filterMachinesByACL(currentMachine, peers, aclRules) } // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. @@ -213,14 +217,14 @@ func filterMachinesByACL( return result } -func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). Msg("Finding direct peers") machines := Machines{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", machine.NodeKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") @@ -237,23 +241,27 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { return machines, nil } -func (h *Headscale) getPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) getPeers( + aclPolicy *ACLPolicy, + aclRules []tailcfg.FilterRule, + machine *Machine, +) (Machines, error) { var peers Machines var err error // If ACLs rules are defined, filter visible host list with the ACLs // else use the classic user scope - if h.aclPolicy != nil { + if aclPolicy != nil { var machines []Machine - machines, err = h.ListMachines() + machines, err = hsdb.ListMachines() if err != nil { log.Error().Err(err).Msg("Error retrieving list of machines") return Machines{}, err } - peers = h.filterMachinesByACL(machine, machines) + peers = hsdb.filterMachinesByACL(aclRules, machine, machines) } else { - peers, err = h.ListPeers(machine) + peers, err = hsdb.ListPeers(machine) if err != nil { log.Error(). Caller(). @@ -275,10 +283,14 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { return peers, nil } -func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) getValidPeers( + aclPolicy *ACLPolicy, + aclRules []tailcfg.FilterRule, + machine *Machine, +) (Machines, error) { validPeers := make(Machines, 0) - peers, err := h.getPeers(machine) + peers, err := hsdb.getPeers(aclPolicy, aclRules, machine) if err != nil { return Machines{}, err } @@ -292,18 +304,18 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { return validPeers, nil } -func (h *Headscale) ListMachines() ([]Machine, error) { +func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { return nil, err } return machines, nil } -func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) { +func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) { machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { return nil, err } @@ -311,8 +323,8 @@ func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) } // GetMachine finds a Machine by name and user and returns the Machine struct. -func (h *Headscale) GetMachine(user string, name string) (*Machine, error) { - machines, err := h.ListMachinesByUser(user) +func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err } @@ -327,8 +339,8 @@ func (h *Headscale) GetMachine(user string, name string) (*Machine, error) { } // GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct. -func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machine, error) { - machines, err := h.ListMachinesByUser(user) +func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) { + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err } @@ -343,9 +355,9 @@ func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machi } // GetMachineByID finds a Machine by ID and returns the Machine struct. -func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { +func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { m := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { return nil, result.Error } @@ -353,11 +365,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { } // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. -func (h *Headscale) GetMachineByMachineKey( +func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*Machine, error) { m := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } @@ -365,12 +377,12 @@ func (h *Headscale) GetMachineByMachineKey( } // GetMachineByNodeKey finds a Machine by its current NodeKey. -func (h *Headscale) GetMachineByNodeKey( +func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, ) (*Machine, error) { machine := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", - NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", + util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { return nil, result.Error } @@ -378,14 +390,14 @@ func (h *Headscale) GetMachineByNodeKey( } // GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct. -func (h *Headscale) GetMachineByAnyKey( +func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*Machine, error) { machine := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", - MachinePublicKeyStripPrefix(machineKey), - NodePublicKeyStripPrefix(nodeKey), - NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", + util.MachinePublicKeyStripPrefix(machineKey), + util.NodePublicKeyStripPrefix(nodeKey), + util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { return nil, result.Error } @@ -394,8 +406,8 @@ func (h *Headscale) GetMachineByAnyKey( // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { - if result := h.db.Find(machine).First(&machine); result.Error != nil { +func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { + if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -403,20 +415,28 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { } // SetTags takes a Machine struct pointer and update the forced tags. -func (h *Headscale) SetTags(machine *Machine, tags []string) error { +func (hsdb *HSDatabase) SetTags( + machine *Machine, + tags []string, + // TODO(kradalby): This is a temporary measure to be able to detach the + // database completely from the global h. In the future, as part of this + // reorg, the rules will be generated on a per node basis, and not be prone + // to throwing error at save. + updateACL func() error) error { newTags := []string{} for _, tag := range tags { - if !contains(newTags, tag) { + if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } machine.ForcedTags = newTags - if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) { + if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) { return err } - h.setLastStateChangeToNow() - if err := h.db.Save(machine).Error; err != nil { + hsdb.notifyStateChange() + + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) } @@ -424,13 +444,13 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error { } // ExpireMachine takes a Machine struct and sets the expire field to now. -func (h *Headscale) ExpireMachine(machine *Machine) error { +func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { now := time.Now() machine.Expiry = &now - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to expire machine in the database: %w", err) } @@ -439,7 +459,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. -func (h *Headscale) RenameMachine(machine *Machine, newName string) error { +func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { err := CheckForFQDNRules( newName, ) @@ -455,9 +475,9 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error { } machine.GivenName = newName - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to rename machine in the database: %w", err) } @@ -465,15 +485,15 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error { } // RefreshMachine takes a Machine struct and sets the expire field to now. -func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error { +func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error { now := time.Now() machine.LastSuccessfulUpdate = &now machine.Expiry = &expiry - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf( "failed to refresh machine (update expiration) in the database: %w", err, @@ -484,21 +504,21 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error { } // DeleteMachine softs deletes a Machine from the database. -func (h *Headscale) DeleteMachine(machine *Machine) error { - err := h.DeleteMachineRoutes(machine) +func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err } - if err := h.db.Delete(&machine).Error; err != nil { + if err := hsdb.db.Delete(&machine).Error; err != nil { return err } return nil } -func (h *Headscale) TouchMachine(machine *Machine) error { - return h.db.Updates(Machine{ +func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { + return hsdb.db.Updates(Machine{ ID: machine.ID, LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, @@ -506,13 +526,13 @@ func (h *Headscale) TouchMachine(machine *Machine) error { } // HardDeleteMachine hard deletes a Machine from the database. -func (h *Headscale) HardDeleteMachine(machine *Machine) error { - err := h.DeleteMachineRoutes(machine) +func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err } - if err := h.db.Unscoped().Delete(&machine).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { return err } @@ -524,8 +544,8 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { return tailcfg.Hostinfo(machine.HostInfo) } -func (h *Headscale) isOutdated(machine *Machine) bool { - if err := h.UpdateMachineFromDatabase(machine); err != nil { +func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool { + if err := hsdb.UpdateMachineFromDatabase(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. return true @@ -536,7 +556,6 @@ func (h *Headscale) isOutdated(machine *Machine) bool { // TODO(kradalby): Only request updates from users where we can talk to nodes // This would mostly be for a bit of performance, and can be calculated based on // ACLs. - lastChange := h.getLastStateChange() lastUpdate := machine.CreatedAt if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -576,15 +595,16 @@ func (machines MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (h *Headscale) toNodes( +func (hsdb *HSDatabase) toNodes( machines Machines, + aclPolicy *ACLPolicy, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) for index, machine := range machines { - node, err := h.toNode(machine, baseDomain, dnsConfig) + node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig) if err != nil { return nil, err } @@ -597,13 +617,14 @@ func (h *Headscale) toNodes( // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. -func (h *Headscale) toNode( +func (hsdb *HSDatabase) toNode( machine Machine, + aclPolicy *ACLPolicy, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) (*tailcfg.Node, error) { var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey))) + err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) if err != nil { log.Trace(). Caller(). @@ -617,7 +638,7 @@ func (h *Headscale) toNode( // MachineKey is only used in the legacy protocol if machine.MachineKey != "" { err = machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil { return nil, fmt.Errorf("failed to parse machine public key: %w", err) @@ -627,7 +648,7 @@ func (h *Headscale) toNode( var discoKey key.DiscoPublic if machine.DiscoKey != "" { err := discoKey.UnmarshalText( - []byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), + []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), ) if err != nil { return nil, fmt.Errorf("failed to parse disco public key: %w", err) @@ -646,13 +667,13 @@ func (h *Headscale) toNode( []netip.Prefix{}, addrs...) // we append the node own IP, as it is required by the clients - primaryRoutes, err := h.getMachinePrimaryRoutes(&machine) + primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine) if err != nil { return nil, err } primaryPrefixes := Routes(primaryRoutes).toPrefixes() - machineRoutes, err := h.GetMachineRoutes(&machine) + machineRoutes, err := hsdb.GetMachineRoutes(&machine) if err != nil { return nil, err } @@ -699,13 +720,13 @@ func (h *Headscale) toNode( online := machine.isOnline() - tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain) + tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain) tags = lo.Uniq(append(tags, machine.ForcedTags...)) node := tailcfg.Node{ ID: tailcfg.NodeID(machine.ID), // this is the actual ID StableID: tailcfg.StableNodeID( - strconv.FormatUint(machine.ID, Base10), + strconv.FormatUint(machine.ID, util.Base10), ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, @@ -827,7 +848,8 @@ func getTags( return validTags, invalidTags } -func (h *Headscale) RegisterMachineFromAuthCallback( +func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( + cache *cache.Cache, nodeKeyStr string, userName string, machineExpiry *time.Time, @@ -846,9 +868,9 @@ func (h *Headscale) RegisterMachineFromAuthCallback( Str("expiresAt", fmt.Sprintf("%v", machineExpiry)). Msg("Registering machine from API/CLI or auth callback") - if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok { + if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { if registrationMachine, ok := machineInterface.(Machine); ok { - user, err := h.GetUser(userName) + user, err := hsdb.GetUser(userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register machine from auth callback, %w", @@ -869,12 +891,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback( registrationMachine.Expiry = machineExpiry } - machine, err := h.RegisterMachine( + machine, err := hsdb.RegisterMachine( registrationMachine, ) if err == nil { - h.registrationCache.Delete(nodeKeyStr) + cache.Delete(nodeKeyStr) } return machine, err @@ -887,7 +909,7 @@ func (h *Headscale) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine(machine Machine, +func (hsdb *HSDatabase) RegisterMachine(machine Machine, ) (*Machine, error) { log.Debug(). Str("machine", machine.Hostname). @@ -900,7 +922,7 @@ func (h *Headscale) RegisterMachine(machine Machine, // so we store the machine.Expire and machine.Nodekey that has been set when // adding it to the registrationCache if len(machine.IPAddresses) > 0 { - if err := h.db.Save(&machine).Error; err != nil { + if err := hsdb.db.Save(&machine).Error; err != nil { return nil, fmt.Errorf("failed register existing machine in the database: %w", err) } @@ -915,10 +937,10 @@ func (h *Headscale) RegisterMachine(machine Machine, return &machine, nil } - h.ipAllocationMutex.Lock() - defer h.ipAllocationMutex.Unlock() + hsdb.ipAllocationMutex.Lock() + defer hsdb.ipAllocationMutex.Unlock() - ips, err := h.getAvailableIPs() + ips, err := hsdb.getAvailableIPs() if err != nil { log.Error(). Caller(). @@ -931,7 +953,7 @@ func (h *Headscale) RegisterMachine(machine Machine, machine.IPAddresses = ips - if err := h.db.Save(&machine).Error; err != nil { + if err := hsdb.db.Save(&machine).Error; err != nil { return nil, fmt.Errorf("failed register(save) machine in the database: %w", err) } @@ -945,10 +967,10 @@ func (h *Headscale) RegisterMachine(machine Machine, } // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. -func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { +func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -970,10 +992,10 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error } // GetEnabledRoutes returns the routes that are enabled for the machine. -func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { +func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true). Find(&routes).Error @@ -995,13 +1017,13 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { return prefixes, nil } -func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { +func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := h.GetEnabledRoutes(machine) + enabledRoutes, err := hsdb.GetEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -1018,7 +1040,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { } // enableRoutes enables new routes based on a list of new routes. -func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { +func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -1029,13 +1051,13 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { newRoutes[index] = route } - advertisedRoutes, err := h.GetAdvertisedRoutes(machine) + advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) if err != nil { return err } for _, newRoute := range newRoutes { - if !contains(advertisedRoutes, newRoute) { + if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { return fmt.Errorf( "route (%s) is not available on node %s: %w", machine.Hostname, @@ -1047,7 +1069,7 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { route := Route{} - err := h.db.Preload("Machine"). + err := hsdb.db.Preload("Machine"). Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). First(&route).Error if err == nil { @@ -1056,10 +1078,10 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) if !route.isExitRoute() { - route.IsPrimary = h.isUniquePrefix(route) + route.IsPrimary = hsdb.isUniquePrefix(route) } - err = h.db.Save(&route).Error + err = hsdb.db.Save(&route).Error if err != nil { return fmt.Errorf("failed to enable route: %w", err) } @@ -1068,19 +1090,19 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { } } - h.setLastStateChangeToNow() + hsdb.notifyStateChange() return nil } // EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. -func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { +func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error { if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID). Find(&routes).Error @@ -1097,7 +1119,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { approvedRoutes := []Route{} for _, advertisedRoute := range routes { - routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers( + routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( netip.Prefix(advertisedRoute.Prefix), ) if err != nil { @@ -1113,7 +1135,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { if approvedAlias == machine.User.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - approvedIps, err := h.aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, h.cfg.OIDC.StripEmaildomain) + approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain) if err != nil { log.Err(err). Str("alias", approvedAlias). @@ -1132,7 +1154,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { for i, approvedRoute := range approvedRoutes { approvedRoutes[i].Enabled = true - err = h.db.Save(&approvedRoutes[i]).Error + err = hsdb.db.Save(&approvedRoutes[i]).Error if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). @@ -1146,10 +1168,10 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { return nil } -func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { +func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := NormalizeToFQDNRules( suppliedName, - h.cfg.OIDC.StripEmaildomain, + hsdb.stripEmailDomain, ) if err != nil { return "", err @@ -1162,7 +1184,7 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s normalizedHostname = normalizedHostname[:trimmedHostnameLength] } - suffix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength) + suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength) if err != nil { return "", err } @@ -1173,21 +1195,21 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s return normalizedHostname, nil } -func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) { - givenName, err := h.generateGivenName(suppliedName, false) +func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { + givenName, err := hsdb.generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machines, err := h.ListMachinesByGivenName(givenName) + machines, err := hsdb.ListMachinesByGivenName(givenName) if err != nil { return "", err } for _, machine := range machines { if machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := h.generateGivenName(suppliedName, true) + postfixedName, err := hsdb.generateGivenName(suppliedName, true) if err != nil { return "", err } diff --git a/hscontrol/machine_test.go b/hscontrol/machine_test.go index 3f11da4..0e7d7de 100644 --- a/hscontrol/machine_test.go +++ b/hscontrol/machine_test.go @@ -9,19 +9,20 @@ import ( "testing" "time" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" ) func (s *Suite) TestGetMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -34,20 +35,20 @@ func (s *Suite) TestGetMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByID(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) machine := Machine{ @@ -60,20 +61,20 @@ func (s *Suite) TestGetMachineByID(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByNodeKey(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -81,28 +82,28 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) { machine := Machine{ ID: 0, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByNodeKey(nodeKey.Public()) + _, err = app.db.GetMachineByNodeKey(nodeKey.Public()) c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -112,22 +113,22 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { machine := Machine{ ID: 0, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) + _, err = app.db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) } func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) machine := Machine{ ID: 0, @@ -139,17 +140,17 @@ func (s *Suite) TestDeleteMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.DeleteMachine(&machine) + err = app.db.DeleteMachine(&machine) c.Assert(err, check.IsNil) - _, err = app.GetMachine(user.Name, "testmachine") + _, err = app.db.GetMachine(user.Name, "testmachine") c.Assert(err, check.NotNil) } func (s *Suite) TestHardDeleteMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) machine := Machine{ ID: 0, @@ -161,23 +162,23 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.HardDeleteMachine(&machine) + err = app.db.HardDeleteMachine(&machine) c.Assert(err, check.IsNil) - _, err = app.GetMachine(user.Name, "testmachine3") + _, err = app.db.GetMachine(user.Name, "testmachine3") c.Assert(err, check.NotNil) } func (s *Suite) TestListPeers(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) for index := 0; index <= 10; index++ { @@ -191,13 +192,13 @@ func (s *Suite) TestListPeers(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) } - machine0ByID, err := app.GetMachineByID(0) + machine0ByID, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) - peersOfMachine0, err := app.ListPeers(machine0ByID) + peersOfMachine0, err := app.db.ListPeers(machine0ByID) c.Assert(err, check.IsNil) c.Assert(len(peersOfMachine0), check.Equals, 9) @@ -215,14 +216,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { stor := make([]base, 0) for _, name := range []string{"test", "admin"} { - user, err := app.CreateUser(name) + user, err := app.db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } - _, err := app.GetMachineByID(0) + _, err := app.db.GetMachineByID(0) c.Assert(err, check.NotNil) for index := 0; index <= 10; index++ { @@ -239,7 +240,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) } app.aclPolicy = &ACLPolicy{ @@ -266,19 +267,19 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { err = app.UpdateACLRules() c.Assert(err, check.IsNil) - adminMachine, err := app.GetMachineByID(1) + adminMachine, err := app.db.GetMachineByID(1) c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) c.Assert(err, check.IsNil) - testMachine, err := app.GetMachineByID(2) + testMachine, err := app.db.GetMachineByID(2) c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) - peersOfTestMachine := app.filterMachinesByACL(testMachine, machines) - peersOfAdminMachine := app.filterMachinesByACL(adminMachine, machines) + peersOfTestMachine := app.db.filterMachinesByACL(app.aclRules, testMachine, machines) + peersOfAdminMachine := app.db.filterMachinesByACL(app.aclRules, adminMachine, machines) c.Log(peersOfTestMachine) c.Assert(len(peersOfTestMachine), check.Equals, 9) @@ -294,13 +295,13 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { } func (s *Suite) TestExpireMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -314,15 +315,15 @@ func (s *Suite) TestExpireMachine(c *check.C) { AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, } - app.db.Save(machine) + app.db.db.Save(machine) - machineFromDB, err := app.GetMachine("test", "testmachine") + machineFromDB, err := app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert(machineFromDB, check.NotNil) c.Assert(machineFromDB.isExpired(), check.Equals, false) - err = app.ExpireMachine(machineFromDB) + err = app.db.ExpireMachine(machineFromDB) c.Assert(err, check.IsNil) c.Assert(machineFromDB.isExpired(), check.Equals, true) @@ -350,13 +351,13 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { } func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := app.CreateUser("user-1") + user1, err := app.db.CreateUser("user-1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user1.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user1.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user-1", "testmachine") + _, err = app.db.GetMachine("user-1", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -370,37 +371,37 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) - givenName, err := app.GenerateGivenName("machine-key-2", "hostname-2") + givenName, err := app.db.GenerateGivenName("machine-key-2", "hostname-2") comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Equals, "hostname-2", comment) - givenName, err = app.GenerateGivenName("machine-key-1", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-1", "hostname-1") comment = check.Commentf("Same user, same machine, same hostname, no conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Equals, "hostname-1", comment) - givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") comment = check.Commentf("Same user, unique machines, same hostname, conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) - givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") comment = check.Commentf("Unique users, unique machines, same hostname, conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) } func (s *Suite) TestSetTags(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -413,21 +414,21 @@ func (s *Suite) TestSetTags(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = app.SetTags(machine, sTags) + err = app.db.SetTags(machine, sTags, app.UpdateACLRules) c.Assert(err, check.IsNil) - machine, err = app.GetMachine("test", "testmachine") + machine, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) // assign duplicat tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = app.SetTags(machine, eTags) + err = app.db.SetTags(machine, eTags, app.UpdateACLRules) c.Assert(err, check.IsNil) - machine, err = app.GetMachine("test", "testmachine") + machine, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert( machine.ForcedTags, @@ -562,7 +563,7 @@ func Test_getTags(t *testing.T) { test.args.stripEmailDomain, ) for _, valid := range gotValid { - if !contains(test.wantValid, valid) { + if !util.StringOrPrefixListContains(test.wantValid, valid) { t.Errorf( "valids: getTags() = %v, want %v", gotValid, @@ -573,7 +574,7 @@ func Test_getTags(t *testing.T) { } } for _, invalid := range gotInvalid { - if !contains(test.wantInvalid, invalid) { + if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { t.Errorf( "invalids: getTags() = %v, want %v", gotInvalid, @@ -1061,19 +1062,15 @@ func TestHeadscale_generateGivenName(t *testing.T) { } tests := []struct { name string - h *Headscale + db *HSDatabase args args want *regexp.Regexp wantErr bool }{ { name: "simple machine name generation", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "testmachine", @@ -1084,12 +1081,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", @@ -1100,12 +1093,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", @@ -1116,12 +1105,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", @@ -1132,12 +1117,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", @@ -1148,12 +1129,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "test", @@ -1164,12 +1141,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", @@ -1181,7 +1154,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.h.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) if (err != nil) != tt.wantErr { t.Errorf( "Headscale.GenerateGivenName() error = %v, wantErr %v", @@ -1214,35 +1187,35 @@ func TestHeadscale_generateGivenName(t *testing.T) { func (s *Suite) TestAutoApproveRoutes(c *check.C) { acl := []byte(` { - "tagOwners": { - "tag:exit": ["test"], - }, + "tagOwners": { + "tag:exit": ["test"], + }, - "groups": { - "group:test": ["test"] - }, + "groups": { + "group:test": ["test"] + }, - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], + "acls": [ + {"action": "accept", "users": ["*"], "ports": ["*:*"]}, + ], - "autoApprovers": { - "exitNode": ["tag:exit"], - "routes": { - "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], - } - } + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test"], + } + } } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) nodeKey := key.NewNode() @@ -1255,7 +1228,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { machine := Machine{ ID: 0, MachineKey: "foo", - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "test", UserID: user.ID, @@ -1268,18 +1241,18 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - machine0ByID, err := app.GetMachineByID(0) + machine0ByID, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) - err = app.EnableAutoApprovedRoutes(machine0ByID) + err = app.db.EnableAutoApprovedRoutes(app.aclPolicy, machine0ByID) c.Assert(err, check.IsNil) - enabledRoutes, err := app.GetEnabledRoutes(machine0ByID) + enabledRoutes, err := app.db.GetEnabledRoutes(machine0ByID) c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 3) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 332ce09..c666594 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -14,6 +14,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" "tailscale.com/types/key" @@ -21,16 +22,22 @@ import ( const ( randomByteSize = 16 +) - errEmptyOIDCCallbackParams = Error("empty OIDC callback params") - errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback") - errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") - errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group") - errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") - errOIDCInvalidMachineState = Error( +var ( + errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") + errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback") + errOIDCAllowedDomains = errors.New( + "authenticated principal does not match any allowed domain", + ) + errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group") + errOIDCAllowedUsers = errors.New( + "authenticated principal does not match any allowed user", + ) + errOIDCInvalidMachineState = errors.New( "requested machine state key expired before authorisation completed", ) - errOIDCNodeKeyMissing = Error("could not get node key from cache") + errOIDCNodeKeyMissing = errors.New("could not get node key from cache") ) type IDTokenClaims struct { @@ -94,7 +101,7 @@ func (h *Headscale) RegisterOIDC( Bool("ok", ok). Msg("Received oidc register call") - if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -115,7 +122,7 @@ func (h *Headscale) RegisterOIDC( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), ) if !ok || nodeKeyStr == "" || err != nil { @@ -149,7 +156,11 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] // place the node key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration) + h.registrationCache.Set( + stateStr, + util.NodePublicKeyStripPrefix(nodeKey), + registerCacheExpiration, + ) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -406,7 +417,7 @@ func validateOIDCAllowedDomains( ) error { if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || - !IsStringInSlice(allowedDomains, claims.Email[at+1:]) { + !util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) { log.Error().Msg("authenticated principal does not match any allowed domain") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) @@ -436,7 +447,7 @@ func validateOIDCAllowedGroups( ) error { if len(allowedGroups) > 0 { for _, group := range allowedGroups { - if IsStringInSlice(claims.Groups, group) { + if util.IsStringInSlice(claims.Groups, group) { return nil } } @@ -466,7 +477,7 @@ func validateOIDCAllowedUsers( claims *IDTokenClaims, ) error { if len(allowedUsers) > 0 && - !IsStringInSlice(allowedUsers, claims.Email) { + !util.IsStringInSlice(allowedUsers, claims.Email) { log.Error().Msg("authenticated principal does not match any allowed user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) @@ -531,7 +542,7 @@ func (h *Headscale) validateMachineForOIDCCallback( } err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)), ) if err != nil { log.Error(). @@ -555,7 +566,7 @@ func (h *Headscale) validateMachineForOIDCCallback( // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByNodeKey(nodeKey) + machine, _ := h.db.GetMachineByNodeKey(nodeKey) if machine != nil { log.Trace(). @@ -563,7 +574,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - err := h.RefreshMachine(machine, expiry) + err := h.db.RefreshMachine(machine, expiry) if err != nil { log.Error(). Caller(). @@ -653,9 +664,9 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer http.ResponseWriter, userName string, ) (*User, error) { - user, err := h.GetUser(userName) + user, err := h.db.GetUser(userName) if errors.Is(err, ErrUserNotFound) { - user, err = h.CreateUser(userName) + user, err = h.db.CreateUser(userName) if err != nil { log.Error(). @@ -702,7 +713,9 @@ func (h *Headscale) registerMachineForOIDCCallback( nodeKey *key.NodePublic, expiry time.Time, ) error { - if _, err := h.RegisterMachineFromAuthCallback( + if _, err := h.db.RegisterMachineFromAuthCallback( + // TODO(kradalby): find a better way to use the cache across modules + h.registrationCache, nodeKey.String(), user.Name, &expiry, diff --git a/hscontrol/preauth_keys.go b/hscontrol/preauth_keys.go index 6cff90b..1956762 100644 --- a/hscontrol/preauth_keys.go +++ b/hscontrol/preauth_keys.go @@ -10,16 +10,17 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) -const ( - ErrPreAuthKeyNotFound = Error("AuthKey not found") - ErrPreAuthKeyExpired = Error("AuthKey expired") - ErrSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") - ErrUserMismatch = Error("user mismatch") - ErrPreAuthKeyACLTagInvalid = Error("AuthKey tag is invalid") +var ( + ErrPreAuthKeyNotFound = errors.New("AuthKey not found") + ErrPreAuthKeyExpired = errors.New("AuthKey expired") + ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") + ErrUserMismatch = errors.New("user mismatch") + ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) // PreAuthKey describes a pre-authorization key usable in a particular user. @@ -45,26 +46,30 @@ type PreAuthKeyACLTag struct { } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. -func (h *Headscale) CreatePreAuthKey( +func (hsdb *HSDatabase) CreatePreAuthKey( userName string, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*PreAuthKey, error) { - user, err := h.GetUser(userName) + user, err := hsdb.GetUser(userName) if err != nil { return nil, err } for _, tag := range aclTags { if !strings.HasPrefix(tag, "tag:") { - return nil, fmt.Errorf("%w: '%s' did not begin with 'tag:'", ErrPreAuthKeyACLTagInvalid, tag) + return nil, fmt.Errorf( + "%w: '%s' did not begin with 'tag:'", + ErrPreAuthKeyACLTagInvalid, + tag, + ) } } now := time.Now().UTC() - kstr, err := h.generateKey() + kstr, err := hsdb.generateKey() if err != nil { return nil, err } @@ -79,7 +84,7 @@ func (h *Headscale) CreatePreAuthKey( Expiration: expiration, } - err = h.db.Transaction(func(db *gorm.DB) error { + err = hsdb.db.Transaction(func(db *gorm.DB) error { if err := db.Save(&key).Error; err != nil { return fmt.Errorf("failed to create key in the database: %w", err) } @@ -111,14 +116,14 @@ func (h *Headscale) CreatePreAuthKey( } // ListPreAuthKeys returns the list of PreAuthKeys for a user. -func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { - user, err := h.GetUser(userName) +func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { + user, err := hsdb.GetUser(userName) if err != nil { return nil, err } keys := []PreAuthKey{} - if err := h.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -126,8 +131,8 @@ func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { } // GetPreAuthKey returns a PreAuthKey for a given key. -func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { - pak, err := h.checkKeyValidity(key) +func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { + pak, err := hsdb.checkKeyValidity(key) if err != nil { return nil, err } @@ -141,8 +146,8 @@ func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { - return h.db.Transaction(func(db *gorm.DB) error { +func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { + return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -156,8 +161,8 @@ func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { } // MarkExpirePreAuthKey marks a PreAuthKey as expired. -func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { - if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { +func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { + if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -165,9 +170,9 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { k.Used = true - if err := h.db.Save(k).Error; err != nil { + if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } @@ -176,9 +181,9 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error { // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { +func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { pak := PreAuthKey{} - if result := h.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -194,7 +199,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { } machines := []Machine{} - if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { return nil, err } @@ -205,7 +210,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { return &pak, nil } -func (h *Headscale) generateKey() (string, error) { +func (hsdb *HSDatabase) generateKey() (string, error) { size := 24 bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { @@ -218,7 +223,7 @@ func (h *Headscale) generateKey() (string, error) { func (key *PreAuthKey) toProto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ User: key.User.Name, - Id: strconv.FormatUint(key.ID, Base10), + Id: strconv.FormatUint(key.ID, util.Base10), Key: key.Key, Ephemeral: key.Ephemeral, Reusable: key.Reusable, diff --git a/hscontrol/preauth_keys_test.go b/hscontrol/preauth_keys_test.go index bd383cf..a85a6c6 100644 --- a/hscontrol/preauth_keys_test.go +++ b/hscontrol/preauth_keys_test.go @@ -7,14 +7,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := app.CreatePreAuthKey("bogus", true, false, nil, nil) + _, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) c.Assert(err, check.NotNil) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -24,10 +24,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { // Make sure the User association is populated c.Assert(key.User.Name, check.Equals, user.Name) - _, err = app.ListPreAuthKeys("bogus") + _, err = app.db.ListPreAuthKeys("bogus") c.Assert(err, check.NotNil) - keys, err := app.ListPreAuthKeys(user.Name) + keys, err := app.db.ListPreAuthKeys(user.Name) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) @@ -36,41 +36,41 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { } func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := app.CreateUser("test2") + user, err := app.db.CreateUser("test2") c.Assert(err, check.IsNil) now := time.Now() - pak, err := app.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := app.checkKeyValidity("potatoKey") + key, err := app.db.checkKeyValidity("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) c.Assert(key, check.IsNil) } func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := app.CreateUser("test3") + user, err := app.db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := app.CreateUser("test4") + user, err := app.db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -83,18 +83,18 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(key, check.IsNil) } func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := app.CreateUser("test5") + user, err := app.db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -107,30 +107,30 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := app.CreateUser("test6") + user, err := app.db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestEphemeralKey(c *check.C) { - user, err := app.CreateUser("test7") + user, err := app.db.CreateUser("test7") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, true, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) c.Assert(err, check.IsNil) now := time.Now() @@ -145,65 +145,65 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.checkKeyValidity(pak.Key) + _, err = app.db.checkKeyValidity(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = app.GetMachine("test7", "testest") + _, err = app.db.GetMachine("test7", "testest") c.Assert(err, check.IsNil) app.expireEphemeralNodesWorker() // The machine record should have been deleted - _, err = app.GetMachine("test7", "testest") + _, err = app.db.GetMachine("test7", "testest") c.Assert(err, check.NotNil) } func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := app.CreateUser("test3") + user, err := app.db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) - err = app.ExpirePreAuthKey(pak) + err = app.db.ExpirePreAuthKey(pak) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.NotNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := app.CreateUser("test6") + user, err := app.db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - app.db.Save(&pak) + app.db.db.Save(&pak) - _, err = app.checkKeyValidity(pak.Key) + _, err = app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) } func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := app.CreateUser("test8") + user, err := app.db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = app.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = app.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := app.ListPreAuthKeys("test8") + listedPaks, err := app.db.ListPreAuthKeys("test8") c.Assert(err, check.IsNil) c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) } diff --git a/hscontrol/protocol_common.go b/hscontrol/protocol_common.go index 97da464..5cd0ddb 100644 --- a/hscontrol/protocol_common.go +++ b/hscontrol/protocol_common.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -82,7 +83,7 @@ func (h *Headscale) KeyHandler( // Old clients don't send a 'v' parameter, so we send the legacy public key writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public()))) + _, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public()))) if err != nil { log.Error(). Caller(). @@ -102,7 +103,7 @@ func (h *Headscale) handleRegisterCommon( isNoise bool, ) { now := time.Now().UTC() - machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + machine, err := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if errors.Is(err, gorm.ErrRecordNotFound) { // If the machine has AuthKey set, handle registration via PreAuthKeys if registerRequest.Auth.AuthKey != "" { @@ -120,7 +121,7 @@ func (h *Headscale) handleRegisterCommon( // is that the client will hammer headscale with requests until it gets a // successful RegisterResponse. if registerRequest.Followup != "" { - if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { + if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { log.Debug(). Caller(). Str("machine", registerRequest.Hostinfo.Hostname). @@ -152,7 +153,7 @@ func (h *Headscale) handleRegisterCommon( Bool("noise", isNoise). Msg("New machine not yet in the database") - givenName, err := h.GenerateGivenName( + givenName, err := h.db.GenerateGivenName( machineKey.String(), registerRequest.Hostinfo.Hostname, ) @@ -171,10 +172,10 @@ func (h *Headscale) handleRegisterCommon( // We create the machine and then keep it around until a callback // happens newMachine := Machine{ - MachineKey: MachinePublicKeyStripPrefix(machineKey), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, - NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey), + NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey), LastSeen: &now, Expiry: &time.Time{}, } @@ -210,11 +211,11 @@ func (h *Headscale) handleRegisterCommon( // So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it. var storedMachineKey key.MachinePublic err = storedMachineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil || storedMachineKey.IsZero() { - machine.MachineKey = MachinePublicKeyStripPrefix(machineKey) - if err := h.db.Save(&machine).Error; err != nil { + machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) + if err := h.db.db.Save(&machine).Error; err != nil { log.Error(). Caller(). Str("func", "RegistrationHandler"). @@ -231,7 +232,7 @@ func (h *Headscale) handleRegisterCommon( // - Trying to log out (sending a expiry in the past) // - A valid, registered machine, looking for /map // - Expired machine wanting to reauthenticate - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { + if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !registerRequest.Expiry.IsZero() && @@ -251,7 +252,7 @@ func (h *Headscale) handleRegisterCommon( } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && + if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && !machine.isExpired() { h.handleMachineRefreshKeyCommon( writer, @@ -282,9 +283,9 @@ func (h *Headscale) handleRegisterCommon( // we need to make sure the NodeKey matches the one in the request // TODO(juan): What happens when using fast user switching between two // headscale-managed tailnets? - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) h.registrationCache.Set( - NodePublicKeyStripPrefix(registerRequest.NodeKey), + util.NodePublicKeyStripPrefix(registerRequest.NodeKey), *machine, registerCacheExpiration, ) @@ -311,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) + pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). @@ -372,13 +373,13 @@ func (h *Headscale) handleAuthKeyCommon( Str("machine", registerRequest.Hostinfo.Hostname). Msg("Authentication key was valid, proceeding to acquire IP addresses") - nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) + nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey) // retrieve machine information if it exist // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + machine, _ := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if machine != nil { log.Trace(). Caller(). @@ -388,7 +389,7 @@ func (h *Headscale) handleAuthKeyCommon( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - err := h.RefreshMachine(machine, registerRequest.Expiry) + err := h.db.RefreshMachine(machine, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -403,7 +404,7 @@ func (h *Headscale) handleAuthKeyCommon( aclTags := pak.toProto().AclTags if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.SetTags(machine, aclTags) + err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) if err != nil { log.Error(). @@ -420,7 +421,7 @@ func (h *Headscale) handleAuthKeyCommon( } else { now := time.Now().UTC() - givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) + givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). Caller(). @@ -436,7 +437,7 @@ func (h *Headscale) handleAuthKeyCommon( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, - MachineKey: MachinePublicKeyStripPrefix(machineKey), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), RegisterMethod: RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, @@ -445,7 +446,7 @@ func (h *Headscale) handleAuthKeyCommon( ForcedTags: pak.toProto().AclTags, } - machine, err = h.RegisterMachine( + machine, err = h.db.RegisterMachine( machineToRegister, ) if err != nil { @@ -462,7 +463,7 @@ func (h *Headscale) handleAuthKeyCommon( } } - err = h.UsePreAuthKey(pak) + err = h.db.UsePreAuthKey(pak) if err != nil { log.Error(). Caller(). @@ -591,7 +592,7 @@ func (h *Headscale) handleMachineLogOutCommon( Str("machine", machine.Hostname). Msg("Client requested logout") - err := h.ExpireMachine(&machine) + err := h.db.ExpireMachine(&machine) if err != nil { log.Error(). Caller(). @@ -634,7 +635,7 @@ func (h *Headscale) handleMachineLogOutCommon( } if machine.isEphemeral() { - err = h.HardDeleteMachine(&machine) + err = h.db.HardDeleteMachine(&machine) if err != nil { log.Error(). Err(err). @@ -720,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( Bool("noise", isNoise). Str("machine", machine.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) - if err := h.db.Save(&machine).Error; err != nil { + if err := h.db.db.Save(&machine).Error; err != nil { log.Error(). Caller(). Err(err). diff --git a/hscontrol/protocol_common_poll.go b/hscontrol/protocol_common_poll.go index f267c99..502c633 100644 --- a/hscontrol/protocol_common_poll.go +++ b/hscontrol/protocol_common_poll.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -29,10 +30,10 @@ func (h *Headscale) handlePollCommon( ) { machine.Hostname = mapRequest.Hostinfo.Hostname machine.HostInfo = HostInfo(*mapRequest.Hostinfo) - machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) now := time.Now().UTC() - err := h.processMachineRoutes(machine) + err := h.db.processMachineRoutes(machine) if err != nil { log.Error(). Caller(). @@ -53,7 +54,7 @@ func (h *Headscale) handlePollCommon( } // update routes with peer information - err = h.EnableAutoApprovedRoutes(machine) + err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) if err != nil { log.Error(). Caller(). @@ -77,7 +78,7 @@ func (h *Headscale) handlePollCommon( machine.LastSeen = &now } - if err := h.db.Updates(machine).Error; err != nil { + if err := h.db.db.Updates(machine).Error; err != nil { if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -325,7 +326,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -346,7 +347,7 @@ func (h *Headscale) pollNetMapStream( Set(float64(now.Unix())) machine.LastSuccessfulUpdate = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -409,7 +410,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -425,7 +426,7 @@ func (h *Headscale) pollNetMapStream( } now := time.Now().UTC() machine.LastSeen = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -456,7 +457,7 @@ func (h *Headscale) pollNetMapStream( updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). Inc() - if h.isOutdated(machine) { + if h.db.isOutdated(machine, h.getLastStateChange()) { var lastUpdate time.Time if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -524,7 +525,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -544,7 +545,7 @@ func (h *Headscale) pollNetMapStream( Set(float64(now.Unix())) machine.LastSuccessfulUpdate = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -578,7 +579,7 @@ func (h *Headscale) pollNetMapStream( // TODO: Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err := h.UpdateMachineFromDatabase(machine) + err := h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -594,7 +595,7 @@ func (h *Headscale) pollNetMapStream( } now := time.Now().UTC() machine.LastSeen = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). diff --git a/hscontrol/protocol_common_utils.go b/hscontrol/protocol_common_utils.go index e05b04a..1dababa 100644 --- a/hscontrol/protocol_common_utils.go +++ b/hscontrol/protocol_common_utils.go @@ -5,6 +5,7 @@ import ( "encoding/json" "sync" + "github.com/juanfont/headscale/hscontrol/util" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "tailscale.com/smallzstd" @@ -27,7 +28,7 @@ func (h *Headscale) getMapResponseData( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) if err != nil { log.Error(). Caller(). @@ -50,11 +51,16 @@ func (h *Headscale) getMapKeepAliveResponseData( } if isNoise { - return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress, isNoise) + return h.marshalMapResponse( + keepAliveResponse, + key.MachinePublic{}, + mapRequest.Compress, + isNoise, + ) } var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) if err != nil { log.Error(). Caller(). @@ -104,7 +110,7 @@ func (h *Headscale) marshalMapResponse( } var respBody []byte - if compression == ZstdCompression { + if compression == util.ZstdCompression { respBody = zstdEncode(jsonBody) if !isNoise { // if legacy protocol respBody = h.privateKey.SealTo(machineKey, respBody) diff --git a/hscontrol/protocol_legacy.go b/hscontrol/protocol_legacy.go index 6712828..f443eba 100644 --- a/hscontrol/protocol_legacy.go +++ b/hscontrol/protocol_legacy.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -32,7 +33,7 @@ func (h *Headscale) RegistrationHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Caller(). @@ -44,7 +45,7 @@ func (h *Headscale) RegistrationHandler( return } registerRequest := tailcfg.RegisterRequest{} - err = decode(body, ®isterRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/protocol_legacy_poll.go b/hscontrol/protocol_legacy_poll.go index 0121bf3..3755faf 100644 --- a/hscontrol/protocol_legacy_poll.go +++ b/hscontrol/protocol_legacy_poll.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -44,7 +45,7 @@ func (h *Headscale) PollNetMapHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -56,7 +57,7 @@ func (h *Headscale) PollNetMapHandler( return } mapRequest := tailcfg.MapRequest{} - err = decode(body, &mapRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -67,7 +68,7 @@ func (h *Headscale) PollNetMapHandler( return } - machine, err := h.GetMachineByMachineKey(machineKey) + machine, err := h.db.GetMachineByMachineKey(machineKey) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). diff --git a/hscontrol/protocol_noise_poll.go b/hscontrol/protocol_noise_poll.go index 38f2b1c..c0790f9 100644 --- a/hscontrol/protocol_noise_poll.go +++ b/hscontrol/protocol_noise_poll.go @@ -48,7 +48,11 @@ func (ns *noiseServer) NoisePollNetMapHandler( ns.nodeKey = mapRequest.NodeKey - machine, err := ns.headscale.GetMachineByAnyKey(ns.conn.Peer(), mapRequest.NodeKey, key.NodePublic{}) + machine, err := ns.headscale.db.GetMachineByAnyKey( + ns.conn.Peer(), + mapRequest.NodeKey, + key.NodePublic{}, + ) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). diff --git a/hscontrol/routes.go b/hscontrol/routes.go index 89f9a69..e3be2f6 100644 --- a/hscontrol/routes.go +++ b/hscontrol/routes.go @@ -11,13 +11,10 @@ import ( "gorm.io/gorm" ) -const ( - ErrRouteIsNotAvailable = Error("route is not available") -) - var ( - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") + ErrRouteIsNotAvailable = errors.New("route is not available") + ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") + ExitRouteV6 = netip.MustParsePrefix("::/0") ) type Route struct { @@ -51,9 +48,9 @@ func (rs Routes) toPrefixes() []netip.Prefix { return prefixes } -func (h *Headscale) GetRoutes() ([]Route, error) { +func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { var routes []Route - err := h.db.Preload("Machine").Find(&routes).Error + err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { return nil, err } @@ -61,9 +58,9 @@ func (h *Headscale) GetRoutes() ([]Route, error) { return routes, nil } -func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) { +func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ?", m.ID). Find(&routes).Error @@ -74,9 +71,9 @@ func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (h *Headscale) GetRoute(id uint64) (*Route, error) { +func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { var route Route - err := h.db.Preload("Machine").First(&route, id).Error + err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { return nil, err } @@ -84,8 +81,8 @@ func (h *Headscale) GetRoute(id uint64) (*Route, error) { return &route, nil } -func (h *Headscale) EnableRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) EnableRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -94,14 +91,14 @@ func (h *Headscale) EnableRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if route.isExitRoute() { - return h.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) + return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) } - return h.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) + return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) } -func (h *Headscale) DisableRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) DisableRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -112,15 +109,15 @@ func (h *Headscale) DisableRoute(id uint64) error { if !route.isExitRoute() { route.Enabled = false route.IsPrimary = false - err = h.db.Save(route).Error + err = hsdb.db.Save(route).Error if err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := h.GetMachineRoutes(&route.Machine) + routes, err := hsdb.GetMachineRoutes(&route.Machine) if err != nil { return err } @@ -129,18 +126,18 @@ func (h *Headscale) DisableRoute(id uint64) error { if routes[i].isExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false - err = h.db.Save(&routes[i]).Error + err = hsdb.db.Save(&routes[i]).Error if err != nil { return err } } } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (h *Headscale) DeleteRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) DeleteRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -149,14 +146,14 @@ func (h *Headscale) DeleteRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.isExitRoute() { - if err := h.db.Unscoped().Delete(&route).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := h.GetMachineRoutes(&route.Machine) + routes, err := hsdb.GetMachineRoutes(&route.Machine) if err != nil { return err } @@ -168,32 +165,32 @@ func (h *Headscale) DeleteRoute(id uint64) error { } } - if err := h.db.Unscoped().Delete(&routesToDelete).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (h *Headscale) DeleteMachineRoutes(m *Machine) error { - routes, err := h.GetMachineRoutes(m) +func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { + routes, err := hsdb.GetMachineRoutes(m) if err != nil { return err } for i := range routes { - if err := h.db.Unscoped().Delete(&routes[i]).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { return err } } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. -func (h *Headscale) isUniquePrefix(route Route) bool { +func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { var count int64 - h.db. + hsdb.db. Model(&Route{}). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, @@ -203,9 +200,9 @@ func (h *Headscale) isUniquePrefix(route Route) bool { return count == 0 } -func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { +func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { var route Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). First(&route).Error @@ -222,9 +219,9 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { +func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). Find(&routes).Error @@ -235,9 +232,9 @@ func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (h *Headscale) processMachineRoutes(machine *Machine) error { +func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { currentRoutes := []Route{} - err := h.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error + err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err } @@ -251,7 +248,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { currentRoutes[pos].Advertised = true - err := h.db.Save(¤tRoutes[pos]).Error + err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { return err } @@ -260,7 +257,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { } else if route.Advertised { currentRoutes[pos].Advertised = false currentRoutes[pos].Enabled = false - err := h.db.Save(¤tRoutes[pos]).Error + err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { return err } @@ -275,7 +272,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { Advertised: true, Enabled: false, } - err := h.db.Create(&route).Error + err := hsdb.db.Create(&route).Error if err != nil { return err } @@ -285,10 +282,10 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { return nil } -func (h *Headscale) handlePrimarySubnetFailover() error { +func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { // first, get all the enabled routes var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("advertised = ? AND enabled = ?", true, true). Find(&routes).Error @@ -303,14 +300,14 @@ func (h *Headscale) handlePrimarySubnetFailover() error { } if !route.IsPrimary { - _, err := h.getPrimaryRoute(netip.Prefix(route.Prefix)) - if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { + _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) + if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { log.Info(). Str("prefix", netip.Prefix(route.Prefix).String()). Str("machine", route.Machine.GivenName). Msg("Setting primary route") routes[pos].IsPrimary = true - err := h.db.Save(&routes[pos]).Error + err := hsdb.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error marking route as primary") @@ -336,7 +333,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // find a new primary route var newPrimaryRoutes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, @@ -375,7 +372,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // disable the old primary route routes[pos].IsPrimary = false - err = h.db.Save(&routes[pos]).Error + err = hsdb.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error disabling old primary route") @@ -384,7 +381,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // enable the new primary route newPrimaryRoute.IsPrimary = true - err = h.db.Save(&newPrimaryRoute).Error + err = hsdb.db.Save(&newPrimaryRoute).Error if err != nil { log.Error().Err(err).Msg("error enabling new primary route") @@ -396,7 +393,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { } if routesChanged { - h.setLastStateChangeToNow() + hsdb.notifyStateChange() } return nil diff --git a/hscontrol/routes_test.go b/hscontrol/routes_test.go index 1e5e2bb..cf437a4 100644 --- a/hscontrol/routes_test.go +++ b/hscontrol/routes_test.go @@ -4,19 +4,20 @@ import ( "net/netip" "time" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" ) func (s *Suite) TestGetRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_get_route_machine") + _, err = app.db.GetMachine("test", "test_get_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -37,30 +38,30 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - advertisedRoutes, err := app.GetAdvertisedRoutes(&machine) + advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = app.enableRoutes(&machine, "192.168.0.0/24") + err = app.db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) } func (s *Suite) TestGetEnableRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -88,54 +89,54 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - availableRoutes, err := app.GetAdvertisedRoutes(&machine) + availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(err, check.IsNil) c.Assert(len(availableRoutes), check.Equals, 2) - noEnabledRoutes, err := app.GetEnabledRoutes(&machine) + noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = app.enableRoutes(&machine, "192.168.0.0/24") + err = app.db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes, err := app.GetEnabledRoutes(&machine) + enabledRoutes, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enableRoutesAfterDoubleApply, err := app.GetEnabledRoutes(&machine) + enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = app.enableRoutes(&machine, "150.0.10.0/25") + err = app.db.enableRoutes(&machine, "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutesWithAdditionalRoute, err := app.GetEnabledRoutes(&machine) + enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) } func (s *Suite) TestIsUniquePrefix(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -162,15 +163,15 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo1), } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, route.String()) + err = app.db.enableRoutes(&machine1, route.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, route2.String()) + err = app.db.enableRoutes(&machine1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ @@ -187,39 +188,39 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo2), } - app.db.Save(&machine2) + app.db.db.Save(&machine2) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, route2.String()) + err = app.db.enableRoutes(&machine2, route2.String()) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.getMachinePrimaryRoutes(&machine1) + routes, err := app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) } func (s *Suite) TestSubnetFailover(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -249,25 +250,25 @@ func (s *Suite) TestSubnetFailover(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix2.String()) + err = app.db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - route, err := app.getPrimaryRoute(prefix) + route, err := app.db.getPrimaryRoute(prefix) c.Assert(err, check.IsNil) c.Assert(route.MachineID, check.Equals, machine1.ID) @@ -286,70 +287,70 @@ func (s *Suite) TestSubnetFailover(c *check.C) { HostInfo: HostInfo(hostInfo2), LastSeen: &now, } - app.db.Save(&machine2) + app.db.db.Save(&machine2) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, prefix2.String()) + err = app.db.enableRoutes(&machine2, prefix2.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err = app.GetEnabledRoutes(&machine1) + enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.getMachinePrimaryRoutes(&machine1) + routes, err := app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) // lets make machine1 lastseen 10 mins ago before := now.Add(-10 * time.Minute) machine1.LastSeen = &before - err = app.db.Save(&machine1).Error + err = app.db.db.Save(&machine1).Error c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.getMachinePrimaryRoutes(&machine1) + routes, err = app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix, prefix2}, }) - err = app.db.Save(&machine2).Error + err = app.db.db.Save(&machine2).Error c.Assert(err, check.IsNil) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, prefix.String()) + err = app.db.enableRoutes(&machine2, prefix.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.getMachinePrimaryRoutes(&machine1) + routes, err = app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) } @@ -358,13 +359,13 @@ func (s *Suite) TestSubnetFailover(c *check.C) { // including both the primary routes the node is responsible for, and the // exit node routes if enabled. func (s *Suite) TestAllowedIPRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -398,9 +399,9 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { now := time.Now() machine1 := Machine{ ID: 1, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: DiscoPublicKeyStripPrefix(discoKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), Hostname: "test_enable_route_machine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, @@ -408,23 +409,23 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) // We do not enable this one on purpose to test that it is not enabled - // err = app.enableRoutes(&machine1, prefix2.String()) + // err = app.db.enableRoutes(&machine1, prefix2.String()) // c.Assert(err, check.IsNil) - routes, err := app.GetMachineRoutes(&machine1) + routes, err := app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) for _, route := range routes { if route.isExitRoute() { - err = app.EnableRoute(uint64(route.ID)) + err = app.db.EnableRoute(uint64(route.ID)) c.Assert(err, check.IsNil) // We only enable one exit route, so we can test that both are enabled @@ -432,14 +433,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 3) - peer, err := app.toNode(machine1, "headscale.net", nil) + peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) c.Assert(err, check.IsNil) c.Assert(len(peer.AllowedIPs), check.Equals, 3) @@ -469,35 +470,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.DisableRoute(uint64(exitRouteV4.ID)) + err = app.db.DisableRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err = app.GetEnabledRoutes(&machine1) + enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) // and now we delete only one of the exit routes // and we check if both are deleted - routes, err = app.GetMachineRoutes(&machine1) + routes, err = app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 4) - err = app.DeleteRoute(uint64(exitRouteV4.ID)) + err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - routes, err = app.GetMachineRoutes(&machine1) + routes, err = app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) } func (s *Suite) TestDeleteRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -527,24 +528,24 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix2.String()) + err = app.db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - routes, err := app.GetMachineRoutes(&machine1) + routes, err := app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.DeleteRoute(uint64(routes[0].ID)) + err = app.db.DeleteRoute(uint64(routes[0].ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) } diff --git a/hscontrol/users.go b/hscontrol/users.go index 8782a89..fb3cea9 100644 --- a/hscontrol/users.go +++ b/hscontrol/users.go @@ -9,17 +9,18 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" ) -const ( - ErrUserExists = Error("User already exists") - ErrUserNotFound = Error("User not found") - ErrUserStillHasNodes = Error("User not empty: node(s) found") - ErrInvalidUserName = Error("Invalid user name") +var ( + ErrUserExists = errors.New("user already exists") + ErrUserNotFound = errors.New("user not found") + ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrInvalidUserName = errors.New("invalid user name") ) const ( @@ -40,17 +41,17 @@ type User struct { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (h *Headscale) CreateUser(name string) (*User, error) { +func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { err := CheckForFQDNRules(name) if err != nil { return nil, err } user := User{} - if err := h.db.Where("name = ?", name).First(&user).Error; err == nil { + if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } user.Name = name - if err := h.db.Create(&user).Error; err != nil { + if err := hsdb.db.Create(&user).Error; err != nil { log.Error(). Str("func", "CreateUser"). Err(err). @@ -64,13 +65,13 @@ func (h *Headscale) CreateUser(name string) (*User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. -func (h *Headscale) DestroyUser(name string) error { - user, err := h.GetUser(name) +func (hsdb *HSDatabase) DestroyUser(name string) error { + user, err := hsdb.GetUser(name) if err != nil { return ErrUserNotFound } - machines, err := h.ListMachinesByUser(name) + machines, err := hsdb.ListMachinesByUser(name) if err != nil { return err } @@ -78,18 +79,18 @@ func (h *Headscale) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := h.ListPreAuthKeys(name) + keys, err := hsdb.ListPreAuthKeys(name) if err != nil { return err } for _, key := range keys { - err = h.DestroyPreAuthKey(key) + err = hsdb.DestroyPreAuthKey(key) if err != nil { return err } } - if result := h.db.Unscoped().Delete(&user); result.Error != nil { + if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { return result.Error } @@ -98,9 +99,9 @@ func (h *Headscale) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func (h *Headscale) RenameUser(oldName, newName string) error { +func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { var err error - oldUser, err := h.GetUser(oldName) + oldUser, err := hsdb.GetUser(oldName) if err != nil { return err } @@ -108,7 +109,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = h.GetUser(newName) + _, err = hsdb.GetUser(newName) if err == nil { return ErrUserExists } @@ -118,7 +119,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error { oldUser.Name = newName - if result := h.db.Save(&oldUser); result.Error != nil { + if result := hsdb.db.Save(&oldUser); result.Error != nil { return result.Error } @@ -126,9 +127,9 @@ func (h *Headscale) RenameUser(oldName, newName string) error { } // GetUser fetches a user by name. -func (h *Headscale) GetUser(name string) (*User, error) { +func (hsdb *HSDatabase) GetUser(name string) (*User, error) { user := User{} - if result := h.db.First(&user, "name = ?", name); errors.Is( + if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -139,9 +140,9 @@ func (h *Headscale) GetUser(name string) (*User, error) { } // ListUsers gets all the existing users. -func (h *Headscale) ListUsers() ([]User, error) { +func (hsdb *HSDatabase) ListUsers() ([]User, error) { users := []User{} - if err := h.db.Find(&users).Error; err != nil { + if err := hsdb.db.Find(&users).Error; err != nil { return nil, err } @@ -149,18 +150,18 @@ func (h *Headscale) ListUsers() ([]User, error) { } // ListMachinesByUser gets all the nodes in a given user. -func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) { +func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { err := CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := h.GetUser(name) + user, err := hsdb.GetUser(name) if err != nil { return nil, err } machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { return nil, err } @@ -168,17 +169,17 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) { } // SetMachineUser assigns a Machine to a user. -func (h *Headscale) SetMachineUser(machine *Machine, username string) error { +func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { err := CheckForFQDNRules(username) if err != nil { return err } - user, err := h.GetUser(username) + user, err := hsdb.GetUser(username) if err != nil { return err } machine.User = *user - if result := h.db.Save(&machine); result.Error != nil { + if result := hsdb.db.Save(&machine); result.Error != nil { return result.Error } @@ -211,7 +212,7 @@ func (n *User) toTailscaleLogin() *tailcfg.Login { return &login } -func (h *Headscale) getMapResponseUserProfiles( +func (hsdb *HSDatabase) getMapResponseUserProfiles( machine Machine, peers Machines, ) []tailcfg.UserProfile { @@ -225,8 +226,8 @@ func (h *Headscale) getMapResponseUserProfiles( for _, user := range userMap { displayName := user.Name - if h.cfg.BaseDomain != "" { - displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain) + if hsdb.baseDomain != "" { + displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain) } profiles = append(profiles, @@ -242,7 +243,7 @@ func (h *Headscale) getMapResponseUserProfiles( func (n *User) toProto() *v1.User { return &v1.User{ - Id: strconv.FormatUint(uint64(n.ID), Base10), + Id: strconv.FormatUint(uint64(n.ID), util.Base10), Name: n.Name, CreatedAt: timestamppb.New(n.CreatedAt), } diff --git a/hscontrol/users_test.go b/hscontrol/users_test.go index 12aa988..1d68f92 100644 --- a/hscontrol/users_test.go +++ b/hscontrol/users_test.go @@ -9,42 +9,42 @@ import ( ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) c.Assert(user.Name, check.Equals, "test") - users, err := app.ListUsers() + users, err := app.db.ListUsers() c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.IsNil) - _, err = app.GetUser("test") + _, err = app.db.GetUser("test") c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := app.DestroyUser("test") + err := app.db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserNotFound) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.IsNil) - result := app.db.Preload("User").First(&pak, "key = ?", pak.Key) + result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) // destroying a user also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) - user, err = app.CreateUser("test") + user, err = app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -57,52 +57,52 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) } func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := app.CreateUser("test") + userTest, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) c.Assert(userTest.Name, check.Equals, "test") - users, err := app.ListUsers() + users, err := app.db.ListUsers() c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = app.RenameUser("test", "test-renamed") + err = app.db.RenameUser("test", "test-renamed") c.Assert(err, check.IsNil) - _, err = app.GetUser("test") + _, err = app.db.GetUser("test") c.Assert(err, check.Equals, ErrUserNotFound) - _, err = app.GetUser("test-renamed") + _, err = app.db.GetUser("test-renamed") c.Assert(err, check.IsNil) - err = app.RenameUser("test-does-not-exit", "test") + err = app.db.RenameUser("test-does-not-exit", "test") c.Assert(err, check.Equals, ErrUserNotFound) - userTest2, err := app.CreateUser("test2") + userTest2, err := app.db.CreateUser("test2") c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") - err = app.RenameUser("test2", "test-renamed") + err = app.db.RenameUser("test2", "test-renamed") c.Assert(err, check.Equals, ErrUserExists) } func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyShared1, err := app.CreatePreAuthKey( + preAuthKeyShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -111,7 +111,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyShared2, err := app.CreatePreAuthKey( + preAuthKeyShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -120,7 +120,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyShared3, err := app.CreatePreAuthKey( + preAuthKeyShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -129,7 +129,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKey2Shared1, err := app.CreatePreAuthKey( + preAuthKey2Shared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -138,7 +138,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -153,9 +153,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -170,9 +170,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -187,9 +187,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -204,12 +204,12 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2Shared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) - peersOfMachine1InShared1, err := app.getPeers(machineInShared1) + peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) - userProfiles := app.getMapResponseUserProfiles( + userProfiles := app.db.getMapResponseUserProfiles( *machineInShared1, peersOfMachine1InShared1, ) @@ -378,13 +378,13 @@ func TestCheckForFQDNRules(t *testing.T) { } func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := app.CreateUser("old") + oldUser, err := app.db.CreateUser("old") c.Assert(err, check.IsNil) - newUser, err := app.CreateUser("new") + newUser, err := app.db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -397,18 +397,18 @@ func (s *Suite) TestSetMachineUser(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) c.Assert(machine.UserID, check.Equals, oldUser.ID) - err = app.SetMachineUser(&machine, newUser.Name) + err = app.db.SetMachineUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) - err = app.SetMachineUser(&machine, "non-existing-user") + err = app.db.SetMachineUser(&machine, "non-existing-user") c.Assert(err, check.Equals, ErrUserNotFound) - err = app.SetMachineUser(&machine, newUser.Name) + err = app.db.SetMachineUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go new file mode 100644 index 0000000..d312a6e --- /dev/null +++ b/hscontrol/util/addr.go @@ -0,0 +1,42 @@ +package util + +import ( + "net/netip" + "reflect" + + "go4.org/netipx" +) + +func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { + var network, broadcast netip.Addr + ipRange := netipx.RangeOfPrefix(na) + network = ipRange.From() + broadcast = ipRange.To() + + return network, broadcast +} + +func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { + result := make([]netip.Prefix, len(prefixes)) + + for index, prefixStr := range prefixes { + prefix, err := netip.ParsePrefix(prefixStr) + if err != nil { + return []netip.Prefix{}, err + } + + result[index] = prefix + } + + return result, nil +} + +func StringOrPrefixListContains[T string | netip.Prefix](ts []T, t T) bool { + for _, v := range ts { + if reflect.DeepEqual(v, t) { + return true + } + } + + return false +} diff --git a/hscontrol/util/file.go b/hscontrol/util/file.go new file mode 100644 index 0000000..7b424da --- /dev/null +++ b/hscontrol/util/file.go @@ -0,0 +1,43 @@ +package util + +import ( + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/spf13/viper" +) + +const ( + Base8 = 8 + Base10 = 10 + BitSize16 = 16 + BitSize32 = 32 + BitSize64 = 64 +) + +func AbsolutePathFromConfigPath(path string) string { + // If a relative path is provided, prefix it with the directory where + // the config file was found. + if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) { + dir, _ := filepath.Split(viper.ConfigFileUsed()) + if dir != "" { + path = filepath.Join(dir, path) + } + } + + return path +} + +func GetFileMode(key string) fs.FileMode { + modeStr := viper.GetString(key) + + mode, err := strconv.ParseUint(modeStr, Base8, BitSize64) + if err != nil { + return PermissionFallback + } + + return fs.FileMode(mode) +} diff --git a/hscontrol/util/key.go b/hscontrol/util/key.go new file mode 100644 index 0000000..4eb1db6 --- /dev/null +++ b/hscontrol/util/key.go @@ -0,0 +1,117 @@ +package util + +import ( + "encoding/json" + "errors" + "regexp" + "strings" + + "tailscale.com/types/key" +) + +const ( + + // These constants are copied from the upstream tailscale.com/types/key + // library, because they are not exported. + // https://github.com/tailscale/tailscale/tree/main/types/key + + // nodePublicHexPrefix is the prefix used to identify a + // hex-encoded node public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + nodePublicHexPrefix = "nodekey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" + + // discoPublicHexPrefix is the prefix used to identify a + // hex-encoded disco public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + discoPublicHexPrefix = "discokey:" + + // privateKey prefix. + privateHexPrefix = "privkey:" + + PermissionFallback = 0o700 + + ZstdCompression = "zstd" +) + +var ( + NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") + ErrCannotDecryptResponse = errors.New("cannot decrypt response") +) + +func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { + return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) +} + +func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { + return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) +} + +func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { + return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) +} + +func MachinePublicKeyEnsurePrefix(machineKey string) string { + if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { + return machinePublicHexPrefix + machineKey + } + + return machineKey +} + +func NodePublicKeyEnsurePrefix(nodeKey string) string { + if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { + return nodePublicHexPrefix + nodeKey + } + + return nodeKey +} + +func DiscoPublicKeyEnsurePrefix(discoKey string) string { + if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { + return discoPublicHexPrefix + discoKey + } + + return discoKey +} + +func PrivateKeyEnsurePrefix(privateKey string) string { + if !strings.HasPrefix(privateKey, privateHexPrefix) { + return privateHexPrefix + privateKey + } + + return privateKey +} + +func DecodeAndUnmarshalNaCl( + msg []byte, + output interface{}, + pubKey *key.MachinePublic, + privKey *key.MachinePrivate, +) error { + // log.Trace(). + // Str("pubkey", pubKey.ShortString()). + // Int("length", len(msg)). + // Msg("Trying to decrypt") + + decrypted, ok := privKey.OpenFrom(*pubKey, msg) + if !ok { + return ErrCannotDecryptResponse + } + + if err := json.Unmarshal(decrypted, output); err != nil { + return err + } + + return nil +} diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go new file mode 100644 index 0000000..b704c93 --- /dev/null +++ b/hscontrol/util/net.go @@ -0,0 +1,12 @@ +package util + +import ( + "context" + "net" +) + +func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + + return d.DialContext(ctx, "unix", addr) +} diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go new file mode 100644 index 0000000..6f018af --- /dev/null +++ b/hscontrol/util/string.go @@ -0,0 +1,85 @@ +package util + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "strings" + + "tailscale.com/tailcfg" +) + +// GenerateRandomBytes returns securely generated random bytes. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomBytes(n int) ([]byte, error) { + bytes := make([]byte, n) + + // Note that err == nil only if we read len(b) bytes. + if _, err := rand.Read(bytes); err != nil { + return nil, err + } + + return bytes, nil +} + +// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded +// securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomStringURLSafe(n int) (string, error) { + b, err := GenerateRandomBytes(n) + + return base64.RawURLEncoding.EncodeToString(b), err +} + +// GenerateRandomStringDNSSafe returns a DNS-safe +// securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomStringDNSSafe(size int) (string, error) { + var str string + var err error + for len(str) < size { + str, err = GenerateRandomStringURLSafe(size) + if err != nil { + return "", err + } + str = strings.ToLower( + strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""), + ) + } + + return str[:size], nil +} + +func IsStringInSlice(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true + } + } + + return false +} + +func TailNodesToString(nodes []*tailcfg.Node) string { + temp := make([]string, len(nodes)) + + for index, node := range nodes { + temp[index] = node.Name + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +func TailMapResponseToString(resp tailcfg.MapResponse) string { + return fmt.Sprintf( + "{ Node: %s, Peers: %s }", + resp.Node.Name, + TailNodesToString(resp.Peers), + ) +} diff --git a/hscontrol/util/string_test.go b/hscontrol/util/string_test.go new file mode 100644 index 0000000..87a8be1 --- /dev/null +++ b/hscontrol/util/string_test.go @@ -0,0 +1,15 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateRandomStringDNSSafe(t *testing.T) { + for i := 0; i < 100000; i++ { + str, err := GenerateRandomStringDNSSafe(8) + assert.Nil(t, err) + assert.Len(t, str, 8) + } +} diff --git a/hscontrol/utils.go b/hscontrol/utils.go deleted file mode 100644 index 9cfbf0c..0000000 --- a/hscontrol/utils.go +++ /dev/null @@ -1,361 +0,0 @@ -// Codehere is mostly taken from github.com/tailscale/tailscale -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package hscontrol - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io/fs" - "net" - "net/netip" - "os" - "path/filepath" - "reflect" - "regexp" - "strconv" - "strings" - - "github.com/rs/zerolog/log" - "github.com/spf13/viper" - "go4.org/netipx" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -const ( - ErrCannotDecryptResponse = Error("cannot decrypt response") - ErrCouldNotAllocateIP = Error("could not find any suitable IP") - - // These constants are copied from the upstream tailscale.com/types/key - // library, because they are not exported. - // https://github.com/tailscale/tailscale/tree/main/types/key - - // nodePublicHexPrefix is the prefix used to identify a - // hex-encoded node public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - nodePublicHexPrefix = "nodekey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" - - // discoPublicHexPrefix is the prefix used to identify a - // hex-encoded disco public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - discoPublicHexPrefix = "discokey:" - - // privateKey prefix. - privateHexPrefix = "privkey:" - - PermissionFallback = 0o700 - - ZstdCompression = "zstd" -) - -var NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") - -func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { - return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) -} - -func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { - return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) -} - -func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { - return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) -} - -func MachinePublicKeyEnsurePrefix(machineKey string) string { - if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { - return machinePublicHexPrefix + machineKey - } - - return machineKey -} - -func NodePublicKeyEnsurePrefix(nodeKey string) string { - if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { - return nodePublicHexPrefix + nodeKey - } - - return nodeKey -} - -func DiscoPublicKeyEnsurePrefix(discoKey string) string { - if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { - return discoPublicHexPrefix + discoKey - } - - return discoKey -} - -func PrivateKeyEnsurePrefix(privateKey string) string { - if !strings.HasPrefix(privateKey, privateHexPrefix) { - return privateHexPrefix + privateKey - } - - return privateKey -} - -// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors -type Error string - -func (e Error) Error() string { return string(e) } - -func decode( - msg []byte, - output interface{}, - pubKey *key.MachinePublic, - privKey *key.MachinePrivate, -) error { - log.Trace(). - Str("pubkey", pubKey.ShortString()). - Int("length", len(msg)). - Msg("Trying to decrypt") - - decrypted, ok := privKey.OpenFrom(*pubKey, msg) - if !ok { - return ErrCannotDecryptResponse - } - - if err := json.Unmarshal(decrypted, output); err != nil { - return err - } - - return nil -} - -func (h *Headscale) getAvailableIPs() (MachineAddresses, error) { - var ips MachineAddresses - var err error - ipPrefixes := h.cfg.IPPrefixes - for _, ipPrefix := range ipPrefixes { - var ip *netip.Addr - ip, err = h.getAvailableIP(ipPrefix) - if err != nil { - return ips, err - } - ips = append(ips, *ip) - } - - return ips, err -} - -func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { - var network, broadcast netip.Addr - ipRange := netipx.RangeOfPrefix(na) - network = ipRange.From() - broadcast = ipRange.To() - - return network, broadcast -} - -func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := h.getUsedIPs() - if err != nil { - return nil, err - } - - ipPrefixNetworkAddress, ipPrefixBroadcastAddress := GetIPPrefixEndpoints(ipPrefix) - - // Get the first IP in our prefix - ip := ipPrefixNetworkAddress.Next() - - for { - if !ipPrefix.Contains(ip) { - return nil, ErrCouldNotAllocateIP - } - - switch { - case ip.Compare(ipPrefixBroadcastAddress) == 0: - fallthrough - case usedIps.Contains(ip): - fallthrough - case ip == netip.Addr{} || ip.IsLoopback(): - ip = ip.Next() - - continue - - default: - return &ip, nil - } - } -} - -func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) { - // FIXME: This really deserves a better data model, - // but this was quick to get running and it should be enough - // to begin experimenting with a dual stack tailnet. - var addressesSlices []string - h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - - var ips netipx.IPSetBuilder - for _, slice := range addressesSlices { - var machineAddresses MachineAddresses - err := machineAddresses.Scan(slice) - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to read ip from database: %w", - err, - ) - } - - for _, ip := range machineAddresses { - ips.Add(ip) - } - } - - ipSet, err := ips.IPSet() - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to build IP Set: %w", - err, - ) - } - - return ipSet, nil -} - -func tailNodesToString(nodes []*tailcfg.Node) string { - temp := make([]string, len(nodes)) - - for index, node := range nodes { - temp[index] = node.Name - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -func tailMapResponseToString(resp tailcfg.MapResponse) string { - return fmt.Sprintf( - "{ Node: %s, Peers: %s }", - resp.Node.Name, - tailNodesToString(resp.Peers), - ) -} - -func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { - var d net.Dialer - - return d.DialContext(ctx, "unix", addr) -} - -func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { - result := make([]netip.Prefix, len(prefixes)) - - for index, prefixStr := range prefixes { - prefix, err := netip.ParsePrefix(prefixStr) - if err != nil { - return []netip.Prefix{}, err - } - - result[index] = prefix - } - - return result, nil -} - -func contains[T string | netip.Prefix](ts []T, t T) bool { - for _, v := range ts { - if reflect.DeepEqual(v, t) { - return true - } - } - - return false -} - -// GenerateRandomBytes returns securely generated random bytes. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomBytes(n int) ([]byte, error) { - bytes := make([]byte, n) - - // Note that err == nil only if we read len(b) bytes. - if _, err := rand.Read(bytes); err != nil { - return nil, err - } - - return bytes, nil -} - -// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded -// securely generated random string. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomStringURLSafe(n int) (string, error) { - b, err := GenerateRandomBytes(n) - - return base64.RawURLEncoding.EncodeToString(b), err -} - -// GenerateRandomStringDNSSafe returns a DNS-safe -// securely generated random string. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomStringDNSSafe(size int) (string, error) { - var str string - var err error - for len(str) < size { - str, err = GenerateRandomStringURLSafe(size) - if err != nil { - return "", err - } - str = strings.ToLower( - strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""), - ) - } - - return str[:size], nil -} - -func IsStringInSlice(slice []string, str string) bool { - for _, s := range slice { - if s == str { - return true - } - } - - return false -} - -func AbsolutePathFromConfigPath(path string) string { - // If a relative path is provided, prefix it with the directory where - // the config file was found. - if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) { - dir, _ := filepath.Split(viper.ConfigFileUsed()) - if dir != "" { - path = filepath.Join(dir, path) - } - } - - return path -} - -func GetFileMode(key string) fs.FileMode { - modeStr := viper.GetString(key) - - mode, err := strconv.ParseUint(modeStr, Base8, BitSize64) - if err != nil { - return PermissionFallback - } - - return fs.FileMode(mode) -} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 8ad8f32..452f852 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/ory/dockertest/v3" @@ -220,7 +221,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDC } portNotation := fmt.Sprintf("%d/tcp", port) - hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := fmt.Sprintf("hs-oidcmock-%s", hash) diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index be12808..e9183cd 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -6,7 +6,7 @@ import ( "net/url" "testing" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -110,7 +110,7 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( return err } - hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength) + hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) if err != nil { return err } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 6b1652b..0051b40 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -24,6 +24,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" @@ -132,7 +133,7 @@ func WithHostPortBindings(bindings map[string][]string) Option { // in the Docker container name. func WithTestName(testName string) Option { return func(hsic *HeadscaleInContainer) { - hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength) + hash, _ := util.GenerateRandomStringDNSSafe(hsicHashLength) hostname := fmt.Sprintf("hs-%s-%s", testName, hash) hsic.hostname = hostname @@ -167,7 +168,7 @@ func New( network *dockertest.Network, opts ...Option, ) (*HeadscaleInContainer, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength) + hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) if err != nil { return nil, err } diff --git a/integration/scenario.go b/integration/scenario.go index 5800548..927d6c8 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -10,7 +10,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -105,7 +105,7 @@ type Scenario struct { // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // a set of Users and TailscaleClients. func NewScenario() (*Scenario, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength) + hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) if err != nil { return nil, err } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index cc285f3..ffc7e0a 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -12,7 +12,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" @@ -150,7 +150,7 @@ func New( network *dockertest.Network, opts ...Option, ) (*TailscaleInContainer, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(tsicHashLength) + hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) if err != nil { return nil, err }