diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 206209d..90dd51a 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -7,7 +7,7 @@ import ( "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/types" "github.com/pterm/pterm" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData { continue } - if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 { + if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 { isPrimaryStr = "-" } else { isPrimaryStr = strconv.FormatBool(route.IsPrimary) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 2831dbf..5ce7816 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,6 +10,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc" @@ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { if cfg.ACL.PolicyPath != "" { aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) - err = app.LoadACLPolicyFromPath(aclPath) + pol, err := policy.LoadACLPolicyFromPath(aclPath) if err != nil { log.Fatal(). Str("path", aclPath). Err(err). Msg("Could not load the ACL policy") } + + app.ACLPolicy = pol } return app, nil diff --git a/hscontrol/api.go b/hscontrol/api.go index 8e30141..4a43aeb 100644 --- a/hscontrol/api.go +++ b/hscontrol/api.go @@ -18,9 +18,6 @@ const ( // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authkey" - RegisterMethodOIDC = "oidc" - RegisterMethodCLI = "cli" ) var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( @@ -56,7 +53,7 @@ func (h *Headscale) HealthHandler( } } - if err := h.db.pingDB(req.Context()); err != nil { + if err := h.db.PingDB(req.Context()); err != nil { respond(err) return diff --git a/hscontrol/api_common.go b/hscontrol/api_common.go index f1b3fd8..4d40c1d 100644 --- a/hscontrol/api_common.go +++ b/hscontrol/api_common.go @@ -3,6 +3,7 @@ package hscontrol import ( "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" @@ -10,13 +11,13 @@ import ( func (h *Headscale) generateMapResponse( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, ) (*tailcfg.MapResponse, error) { log.Trace(). Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse( return nil, err } - peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) + peers, err := h.db.GetValidPeers(h.aclRules, machine) if err != nil { log.Error(). Caller(). @@ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse( return nil, err } - profiles := h.db.getMapResponseUserProfiles(*machine, peers) + profiles := h.db.GetMapResponseUserProfiles(*machine, peers) - nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/app.go b/hscontrol/app.go index 38d4ec8..bb68ced 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -23,6 +23,9 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" @@ -73,7 +76,7 @@ const ( // Headscale represents the base app of the service. type Headscale struct { cfg *Config - db *HSDatabase + db *db.HSDatabase dbString string dbType string dbDebug bool @@ -83,7 +86,7 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *DERPServer - aclPolicy *ACLPolicy + ACLPolicy *policy.ACLPolicy aclRules []tailcfg.FilterRule sshPolicy *tailcfg.SSHPolicy @@ -99,6 +102,12 @@ type Headscale struct { stateUpdateChan chan struct{} cancelStateUpdateChan chan struct{} + + // TODO(kradalby): Temporary measure to make sure we can update policy + // across modules, will be removed when aclRules are no longer stored + // globally but generated per node basis. + policyUpdateChan chan struct{} + cancelPolicyUpdateChan chan struct{} } func NewHeadscale(cfg *Config) (*Headscale, error) { @@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { var dbString string switch cfg.DBtype { - case Postgres: + case db.Postgres: dbString = fmt.Sprintf( "host=%s dbname=%s user=%s", cfg.DBhost, @@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { if cfg.DBpass != "" { dbString += fmt.Sprintf(" password=%s", cfg.DBpass) } - case Sqlite: + case db.Sqlite: dbString = cfg.DBpath default: return nil, errUnsupportedDatabase @@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { stateUpdateChan: make(chan struct{}), cancelStateUpdateChan: make(chan struct{}), + + policyUpdateChan: make(chan struct{}), + cancelPolicyUpdateChan: make(chan struct{}), } go app.watchStateChannel() + go app.watchPolicyChannel() - db, err := NewHeadscaleDatabase( + database, err := db.NewHeadscaleDatabase( cfg.DBtype, dbString, cfg.OIDC.StripEmaildomain, app.dbDebug, app.stateUpdateChan, + app.policyUpdateChan, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { return nil, err } - app.db = db + app.db = database if cfg.OIDC.Issuer != "" { err = app.initOIDC() @@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - h.expireEphemeralNodesWorker() + h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout) } } @@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) expireExpiredMachines(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - h.expireExpiredMachinesWorker() + h.db.ExpireExpiredMachines(h.getLastStateChange()) } } func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - err := h.db.handlePrimarySubnetFailover() + err := h.db.HandlePrimarySubnetFailover() if err != nil { log.Error().Err(err).Msg("failed to handle primary subnet failover") } } } -func (h *Headscale) expireEphemeralNodesWorker() { - users, err := h.db.ListUsers() - if err != nil { - log.Error().Err(err).Msg("Error listing users") - - return - } - - for _, user := range users { - machines, err := h.db.ListMachinesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing machines in user") - - return - } - - expiredFound := false - for _, machine := range machines { - if machine.isEphemeral() && machine.LastSeen != nil && - time.Now(). - After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { - expiredFound = true - log.Info(). - Str("machine", machine.Hostname). - Msg("Ephemeral client removed from database") - - err = h.db.db.Unscoped().Delete(machine).Error - if err != nil { - log.Error(). - Err(err). - Str("machine", machine.Hostname). - Msg("🤮 Cannot delete ephemeral machine from the database") - } - } - } - - if expiredFound { - h.setLastStateChangeToNow() - } - } -} - -func (h *Headscale) expireExpiredMachinesWorker() { - users, err := h.db.ListUsers() - if err != nil { - log.Error().Err(err).Msg("Error listing users") - - return - } - - for _, user := range users { - machines, err := h.db.ListMachinesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing machines in user") - - return - } - - expiredFound := false - for index, machine := range machines { - if machine.isExpired() && - machine.Expiry.After(h.getLastStateChange(user)) { - expiredFound = true - - err := h.db.ExpireMachine(&machines[index]) - if err != nil { - log.Error(). - Err(err). - Str("machine", machine.Hostname). - Str("name", machine.GivenName). - Msg("🤮 Cannot expire machine") - } else { - log.Info(). - Str("machine", machine.Hostname). - Str("name", machine.GivenName). - Msg("Machine successfully expired") - } - } - } - - if expiredFound { - h.setLastStateChangeToNow() - } - } -} - func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, @@ -565,6 +487,8 @@ func (h *Headscale) Serve() error { go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) } + // TODO(kradalby): These should have cancel channels and be cleaned + // up on shutdown. go h.expireEphemeralNodes(updateInterval) go h.expireExpiredMachines(updateInterval) @@ -774,10 +698,12 @@ func (h *Headscale) Serve() error { if h.cfg.ACL.PolicyPath != "" { aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) - err := h.LoadACLPolicyFromPath(aclPath) + pol, err := policy.LoadACLPolicyFromPath(aclPath) if err != nil { log.Error().Err(err).Msg("Failed to reload ACL policy") } + + h.ACLPolicy = pol log.Info(). Str("path", aclPath). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -824,12 +750,12 @@ func (h *Headscale) Serve() error { close(h.stateUpdateChan) close(h.cancelStateUpdateChan) + <-h.cancelPolicyUpdateChan + close(h.policyUpdateChan) + close(h.cancelPolicyUpdateChan) + // Close db connections - db, err := h.db.db.DB() - if err != nil { - log.Error().Err(err).Msg("Failed to get db handle") - } - err = db.Close() + err = h.db.Close() if err != nil { log.Error().Err(err).Msg("Failed to close db") } @@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() { } } +// TODO(kradalby): baby steps, make this more robust. +func (h *Headscale) watchPolicyChannel() { + for { + select { + case <-h.policyUpdateChan: + machines, err := h.db.ListMachines() + if err != nil { + log.Error().Err(err).Msg("failed to fetch machines during policy update") + } + + rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain) + if err != nil { + log.Error().Err(err).Msg("failed to update ACL rules") + } + + h.aclRules = rules + h.sshPolicy = sshPolicy + + case <-h.cancelPolicyUpdateChan: + return + } + } +} + func (h *Headscale) setLastStateChangeToNow() { var err error @@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() { } } -func (h *Headscale) getLastStateChange(users ...User) time.Time { +func (h *Headscale) getLastStateChange(users ...types.User) time.Time { times := []time.Time{} // getLastStateChange takes a list of users as a "filter", if no users diff --git a/hscontrol/db/acls_test.go b/hscontrol/db/acls_test.go new file mode 100644 index 0000000..884b6c5 --- /dev/null +++ b/hscontrol/db/acls_test.go @@ -0,0 +1,480 @@ +package db + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "tailscale.com/envknob" + "tailscale.com/tailcfg" +) + +// TODO(kradalby): +// Convert these tests to being non-database dependent and table driven. They are +// very verbose, and dont really need the database. + +func (s *Suite) TestSshRules(c *check.C) { + envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") + + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + aclPolicy := &policy.ACLPolicy{ + Groups: policy.Groups{ + "group:test": []string{"user1"}, + }, + Hosts: policy.Hosts{ + "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), + }, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + SSHs: []policy.SSH{ + { + Action: "accept", + Sources: []string{"group:test"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + }, + } + + _, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) + + c.Assert(err, check.IsNil) + c.Assert(sshPolicy, check.NotNil) + c.Assert(sshPolicy.Rules, check.HasLen, 2) + c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) + c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1) + c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") + + c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) + c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1) + c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") +} + +// this test should validate that we can expand a group in a TagOWner section and +// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. +// the tag is matched in the Sources section. +func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"tag:test"}, + Destinations: []string{"*:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") +} + +// this test should validate that we can expand a group in a TagOWner section and +// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. +// the tag is matched in the Destinations section. +func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"tag:test:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") +} + +// need a test with: +// tag on a host that isn't owned by a tag owners. So the user +// of the host should be valid. +func (s *Suite) TestInvalidTagValidUser(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:foo"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"*:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") +} + +// tag on a host is owned by a tag owner, the tag is valid. +// an ACL rule is matching the tag to a user. It should not be valid since the +// host should be tied to the tag now. +func (s *Suite) TestValidTagInvalidUser(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "webserver") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "webserver", + RequestTags: []string{"tag:webapp"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "webserver", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "user") + hostInfo2 := tailcfg.Hostinfo{ + OS: "debian", + Hostname: "Hostname", + } + c.Assert(err, check.NotNil) + machine = types.Machine{ + ID: 2, + MachineKey: "56789", + NodeKey: "bar2", + DiscoKey: "faab", + Hostname: "user", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo2), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"tag:webapp:80,443"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") + c.Assert(rules[0].DstPorts, check.HasLen, 2) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) + c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") + c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) + c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) + c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") +} + +func (s *Suite) TestPortUser(c *check.C) { + user, err := db.CreateUser("testuser") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("testuser", "testmachine") + c.Assert(err, check.NotNil) + ips, _ := db.getAvailableIPs() + machine := types.Machine{ + ID: 0, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: ips, + AuthKeyID: uint(pak.ID), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + acl := []byte(` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} + `) + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(err, check.IsNil) + c.Assert(rules, check.NotNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") +} + +func (s *Suite) TestPortGroup(c *check.C) { + user, err := db.CreateUser("testuser") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("testuser", "testmachine") + c.Assert(err, check.NotNil) + ips, _ := db.getAvailableIPs() + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: ips, + AuthKeyID: uint(pak.ID), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + acl := []byte(` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} + `) + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.NotNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") +} diff --git a/hscontrol/addresses.go b/hscontrol/db/addresses.go similarity index 87% rename from hscontrol/addresses.go rename to hscontrol/db/addresses.go index 7f78935..1a7d35d 100644 --- a/hscontrol/addresses.go +++ b/hscontrol/db/addresses.go @@ -3,21 +3,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package hscontrol +package db import ( "errors" "fmt" "net/netip" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" ) var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") -func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { - var ips MachineAddresses +func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { + var ips types.MachineAddresses var err error for _, ipPrefix := range hsdb.ipPrefixes { var ip *netip.Addr @@ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string - hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) + hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices) var ips netipx.IPSetBuilder for _, slice := range addressesSlices { - var machineAddresses MachineAddresses + var machineAddresses types.MachineAddresses err := machineAddresses.Scan(slice) if err != nil { return &netipx.IPSet{}, fmt.Errorf( diff --git a/hscontrol/addresses_test.go b/hscontrol/db/addresses_test.go similarity index 71% rename from hscontrol/addresses_test.go rename to hscontrol/db/addresses_test.go index f3be93a..1289148 100644 --- a/hscontrol/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -1,14 +1,16 @@ -package hscontrol +package db import ( "net/netip" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "gopkg.in/check.v1" ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) @@ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { } func (s *Suite) TestGetUsedIps(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) - user, err := app.db.CreateUser("test-ip") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.db.Save(&machine) + db.db.Save(&machine) - usedIps, err := app.db.getUsedIPs() + usedIps, err := db.getUsedIPs() c.Assert(err, check.IsNil) @@ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) c.Assert(usedIps.Contains(expected), check.Equals, true) - machine1, err := app.db.GetMachineByID(0) + machine1, err := db.GetMachineByID(0) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(machine1.IPAddresses[0], check.Equals, expected) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := app.db.CreateUser("test-ip-multi") + user, err := db.CreateUser("test-ip-multi") c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { - app.db.ipAllocationMutex.Lock() + db.ipAllocationMutex.Lock() - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: uint64(index), MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.db.Save(&machine) + db.db.Save(&machine) - app.db.ipAllocationMutex.Unlock() + db.ipAllocationMutex.Unlock() } - usedIps, err := app.db.getUsedIPs() + usedIps, err := db.getUsedIPs() c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs - machine1, err := app.db.GetMachineByID(1) + machine1, err := db.GetMachineByID(1) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert( @@ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { netip.MustParseAddr("10.27.0.1"), ) - machine50, err := app.db.GetMachineByID(50) + machine50, err := db.GetMachineByID(50) c.Assert(err, check.IsNil) c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert( @@ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { ) expectedNextIP := netip.MustParseAddr("10.27.1.95") - nextIP, err := app.db.getAvailableIPs() + nextIP, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP), check.Equals, 1) @@ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) { // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := app.db.getAvailableIPs() + nextIP2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP2), check.Equals, 1) c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) - user, err := app.db.CreateUser("test-ip") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - ips2, err := app.db.getAvailableIPs() + ips2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(ips2), check.Equals, 1) c.Assert(ips2[0].String(), check.Equals, expected.String()) + + c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/api_key.go b/hscontrol/db/api_key.go similarity index 64% rename from hscontrol/api_key.go rename to hscontrol/db/api_key.go index bf2ccf3..4e4030e 100644 --- a/hscontrol/api_key.go +++ b/hscontrol/db/api_key.go @@ -1,4 +1,4 @@ -package hscontrol +package db import ( "errors" @@ -6,10 +6,9 @@ import ( "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" - "google.golang.org/protobuf/types/known/timestamppb" ) const ( @@ -19,22 +18,10 @@ const ( var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") -// APIKey describes the datamodel for API keys used to remotely authenticate with -// headscale. -type APIKey struct { - ID uint64 `gorm:"primary_key"` - Prefix string `gorm:"uniqueIndex"` - Hash []byte - - CreatedAt *time.Time - Expiration *time.Time - LastSeen *time.Time -} - // CreateAPIKey creates a new ApiKey in a user, and returns it. func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, -) (string, *APIKey, error) { +) (string, *types.APIKey, error) { prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey( return "", nil, err } - key := APIKey{ + key := types.APIKey{ Prefix: prefix, Hash: hash, Expiration: expiration, @@ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey( } // ListAPIKeys returns the list of ApiKeys for a user. -func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { - keys := []APIKey{} +func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { + keys := []types.APIKey{} if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err } @@ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { } // GetAPIKey returns a ApiKey for a given key. -func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { - key := APIKey{} +func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { + key := types.APIKey{} if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { } // GetAPIKeyByID returns a ApiKey for a given id. -func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { - key := APIKey{} - if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { +func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + key := types.APIKey{} + if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. -func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { +func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { } // ExpireAPIKey marks a ApiKey as expired. -func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { +func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { return true, nil } - -func (key *APIKey) toProto() *v1.ApiKey { - protoKey := v1.ApiKey{ - Id: key.ID, - Prefix: key.Prefix, - } - - if key.Expiration != nil { - protoKey.Expiration = timestamppb.New(*key.Expiration) - } - - if key.CreatedAt != nil { - protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) - } - - if key.LastSeen != nil { - protoKey.LastSeen = timestamppb.New(*key.LastSeen) - } - - return &protoKey -} diff --git a/hscontrol/api_key_test.go b/hscontrol/db/api_key_test.go similarity index 63% rename from hscontrol/api_key_test.go rename to hscontrol/db/api_key_test.go index 007b5d1..0fc42c5 100644 --- a/hscontrol/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -1,4 +1,4 @@ -package hscontrol +package db import ( "time" @@ -7,7 +7,7 @@ import ( ) func (*Suite) TestCreateAPIKey(c *check.C) { - apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) + apiKeyStr, apiKey, err := db.CreateAPIKey(nil) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) @@ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) { c.Assert(apiKey.Hash, check.NotNil) c.Assert(apiKeyStr, check.Not(check.Equals), "") - _, err = app.db.ListAPIKeys() + _, err = db.ListAPIKeys() c.Assert(err, check.IsNil) - keys, err := app.db.ListAPIKeys() + keys, err := db.ListAPIKeys() c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { - key, err := app.db.GetAPIKey("does-not-exist") + key, err := db.GetAPIKey("does-not-exist") c.Assert(err, check.NotNil) c.Assert(key, check.IsNil) } func (*Suite) TestValidateAPIKeyOk(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, false) now := time.Now() - apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) + apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) + validNow, err := db.ValidateAPIKey(apiKeyStrNow) c.Assert(err, check.IsNil) c.Assert(validNow, check.Equals, false) - validSilly, err := app.db.ValidateAPIKey("nota.validkey") + validSilly, err := db.ValidateAPIKey("nota.validkey") c.Assert(err, check.NotNil) c.Assert(validSilly, check.Equals, false) - validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") + validWithErr, err := db.ValidateAPIKey("produceerrorkey") c.Assert(err, check.NotNil) c.Assert(validWithErr, check.Equals, false) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestExpireAPIKey(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) - err = app.db.ExpireAPIKey(apiKey) + err = db.ExpireAPIKey(apiKey) c.Assert(err, check.IsNil) c.Assert(apiKey.Expiration, check.NotNil) - notValid, err := app.db.ValidateAPIKey(apiKeyStr) + notValid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) + + c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/db.go b/hscontrol/db/db.go similarity index 63% rename from hscontrol/db.go rename to hscontrol/db/db.go index e80a3c3..bc6de08 100644 --- a/hscontrol/db.go +++ b/hscontrol/db/db.go @@ -1,9 +1,7 @@ -package hscontrol +package db import ( "context" - "database/sql/driver" - "encoding/json" "errors" "fmt" "net/netip" @@ -11,11 +9,12 @@ import ( "time" "github.com/glebarez/sqlite" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" - "tailscale.com/tailcfg" ) const ( @@ -26,7 +25,6 @@ const ( var ( errValueNotFound = errors.New("not found") - ErrCannotParsePrefix = errors.New("cannot parse prefix") errDatabaseNotSupported = errors.New("database type not supported") ) @@ -38,8 +36,9 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifyStateChan chan<- struct{} + db *gorm.DB + notifyStateChan chan<- struct{} + notifyPolicyChan chan<- struct{} ipAllocationMutex sync.Mutex @@ -54,6 +53,7 @@ func NewHeadscaleDatabase( dbType, connectionAddr string, stripEmailDomain, debug bool, notifyStateChan chan<- struct{}, + notifyPolicyChan chan<- struct{}, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { @@ -63,8 +63,9 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - db: dbConn, - notifyStateChan: notifyStateChan, + db: dbConn, + notifyStateChan: notifyStateChan, + notifyPolicyChan: notifyPolicyChan, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -79,30 +80,30 @@ func NewHeadscaleDatabase( _ = dbConn.Migrator().RenameTable("namespaces", "users") - err = dbConn.AutoMigrate(User{}) + err = dbConn.AutoMigrate(types.User{}) if err != nil { return nil, err } - _ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") - _ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - _ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") - _ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname") // GivenName is used as the primary source of DNS names, make sure // the field is populated and normalized if it was not when the // machine was registered. - _ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name") // If the Machine table has a column for registered, // find all occourences of "false" and drop them. Then // remove the column. - if dbConn.Migrator().HasColumn(&Machine{}, "registered") { + if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") { log.Info(). Msg(`Database has legacy "registered" column in machine, removing...`) - machines := Machines{} + machines := types.Machines{} if err := dbConn.Not("registered").Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -112,7 +113,7 @@ func NewHeadscaleDatabase( Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). Msg("Deleting unregistered machine") - if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { + if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil { log.Error(). Err(err). Str("machine", machine.Hostname). @@ -121,23 +122,23 @@ func NewHeadscaleDatabase( } } - err := dbConn.Migrator().DropColumn(&Machine{}, "registered") + err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered") if err != nil { log.Error().Err(err).Msg("Error dropping registered column") } } - err = dbConn.AutoMigrate(&Route{}) + err = dbConn.AutoMigrate(&types.Route{}) if err != nil { return nil, err } - if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { + if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") { log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") type MachineAux struct { ID uint64 - EnabledRoutes IPPrefixes + EnabledRoutes types.IPPrefixes } machinesAux := []MachineAux{} @@ -157,8 +158,8 @@ func NewHeadscaleDatabase( } err = dbConn.Preload("Machine"). - Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). - First(&Route{}). + Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). + First(&types.Route{}). Error if err == nil { log.Info(). @@ -168,11 +169,11 @@ func NewHeadscaleDatabase( continue } - route := Route{ + route := types.Route{ MachineID: machine.ID, Advertised: true, Enabled: true, - Prefix: IPPrefix(prefix), + Prefix: types.IPPrefix(prefix), } if err := dbConn.Create(&route).Error; err != nil { log.Error().Err(err).Msg("Error creating route") @@ -185,26 +186,26 @@ func NewHeadscaleDatabase( } } - err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") + err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes") if err != nil { log.Error().Err(err).Msg("Error dropping enabled_routes column") } } - err = dbConn.AutoMigrate(&Machine{}) + err = dbConn.AutoMigrate(&types.Machine{}) if err != nil { return nil, err } - if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { - machines := Machines{} + if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") { + machines := types.Machines{} if err := dbConn.Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } for item, machine := range machines { if machine.GivenName == "" { - normalizedHostname, err := NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRules( machine.Hostname, stripEmailDomain, ) @@ -233,19 +234,19 @@ func NewHeadscaleDatabase( return nil, err } - err = dbConn.AutoMigrate(&PreAuthKey{}) + err = dbConn.AutoMigrate(&types.PreAuthKey{}) if err != nil { return nil, err } - err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) + err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{}) if err != nil { return nil, err } _ = dbConn.Migrator().DropTable("shared_machines") - err = dbConn.AutoMigrate(&APIKey{}) + err = dbConn.AutoMigrate(&types.APIKey{}) if err != nil { return nil, err } @@ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error { return nil } -func (hsdb *HSDatabase) pingDB(ctx context.Context) error { +func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() sqlDB, err := hsdb.db.DB() @@ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error { return sqlDB.PingContext(ctx) } -// This is a "wrapper" type around tailscales -// Hostinfo to allow us to add database "serialization" -// methods. This allows us to use a typed values throughout -// the code and not have to marshal/unmarshal and error -// check all over the code. -type HostInfo tailcfg.Hostinfo - -func (hi *HostInfo) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, hi) - - case string: - return json.Unmarshal([]byte(value), hi) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) +func (hsdb *HSDatabase) Close() error { + db, err := hsdb.db.DB() + if err != nil { + return err } -} - -// Value return json value, implement driver.Valuer interface. -func (hi HostInfo) Value() (driver.Value, error) { - bytes, err := json.Marshal(hi) - - return string(bytes), err -} - -type IPPrefix netip.Prefix - -func (i *IPPrefix) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - prefix, err := netip.ParsePrefix(value) - if err != nil { - return err - } - *i = IPPrefix(prefix) - - return nil - default: - return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefix) Value() (driver.Value, error) { - prefixStr := netip.Prefix(i).String() - - return prefixStr, nil -} - -type IPPrefixes []netip.Prefix - -func (i *IPPrefixes) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefixes) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - -type StringList []string - -func (i *StringList) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i StringList) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err + + return db.Close() } diff --git a/hscontrol/machine.go b/hscontrol/db/machine.go similarity index 58% rename from hscontrol/machine.go rename to hscontrol/db/machine.go index 846112b..a8d3569 100644 --- a/hscontrol/machine.go +++ b/hscontrol/db/machine.go @@ -1,7 +1,6 @@ -package hscontrol +package db import ( - "database/sql/driver" "errors" "fmt" "net/netip" @@ -10,13 +9,12 @@ import ( "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "github.com/samber/lo" - "go4.org/netipx" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -25,13 +23,12 @@ import ( const ( MachineGivenNameHashLength = 8 MachineGivenNameTrimSize = 2 - maxHostnameLength = 255 + MaxHostnameLength = 255 ) var ( ErrMachineNotFound = errors.New("machine not found") ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine") - ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") ErrMachineNotFoundRegistrationCache = errors.New( "machine not found in registration cache", ) @@ -42,193 +39,27 @@ var ( ) ) -// Machine is a Headscale client. -type Machine struct { - ID uint64 `gorm:"primary_key"` - MachineKey string `gorm:"type:varchar(64);unique_index"` - NodeKey string - DiscoKey string - IPAddresses MachineAddresses - - // Hostname represents the name given by the Tailscale - // client during registration - Hostname string - - // Givenname represents either: - // a DNS normalized version of Hostname - // a valid name set by the User - // - // GivenName is the name used in all DNS related - // parts of headscale. - GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"foreignKey:UserID"` - - RegisterMethod string - - ForcedTags StringList - - // TODO(kradalby): This seems like irrelevant information? - AuthKeyID uint - AuthKey *PreAuthKey - - LastSeen *time.Time - LastSuccessfulUpdate *time.Time - Expiry *time.Time - - HostInfo HostInfo - Endpoints StringList - - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type ( - Machines []Machine - MachinesP []*Machine -) - -type MachineAddresses []netip.Addr - -func (ma MachineAddresses) ToStringSlice() []string { - strSlice := make([]string, 0, len(ma)) - for _, addr := range ma { - strSlice = append(strSlice, addr.String()) - } - - return strSlice -} - -// AppendToIPSet adds the individual ips in MachineAddresses to a -// given netipx.IPSetBuilder. -func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { - for _, ip := range ma { - build.Add(ip) - } -} - -func (ma *MachineAddresses) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - addresses := strings.Split(value, ",") - *ma = (*ma)[:0] - for _, addr := range addresses { - if len(addr) < 1 { - continue - } - parsed, err := netip.ParseAddr(addr) - if err != nil { - return err - } - *ma = append(*ma, parsed) - } - - return nil - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (ma MachineAddresses) Value() (driver.Value, error) { - addresses := strings.Join(ma.ToStringSlice(), ",") - - return addresses, nil -} - -// isExpired returns whether the machine registration has expired. -func (machine Machine) isExpired() bool { - // If Expiry is not set, the client has not indicated that - // it wants an expiry time, it is therefor considered - // to mean "not expired" - if machine.Expiry == nil || machine.Expiry.IsZero() { - return false - } - - return time.Now().UTC().After(*machine.Expiry) -} - -// isOnline returns if the machine is connected to Headscale. -// This is really a naive implementation, as we don't really see -// if there is a working connection between the client and the server. -func (machine *Machine) isOnline() bool { - if machine.LastSeen == nil { - return false - } - - if machine.isExpired() { - return false - } - - return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) -} - -// isEphemeral returns if the machine is registered as an Ephemeral node. -// https://tailscale.com/kb/1111/ephemeral-nodes/ -func (machine *Machine) isEphemeral() bool { - return machine.AuthKey != nil && machine.AuthKey.Ephemeral -} - -func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { - for _, rule := range filter { - // TODO(kradalby): Cache or pregen this - matcher := MatchFromFilterRule(rule) - - if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { - continue - } - - if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { - return true - } - } - - return false -} - // filterMachinesByACL wrapper function to not have devs pass around locks and maps // related to the application outside of tests. func (hsdb *HSDatabase) filterMachinesByACL( aclRules []tailcfg.FilterRule, - currentMachine *Machine, peers Machines) Machines { - return filterMachinesByACL(currentMachine, peers, aclRules) + currentMachine *types.Machine, peers types.Machines, +) types.Machines { + return policy.FilterMachinesByACL(currentMachine, peers, aclRules) } -// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. -func filterMachinesByACL( - machine *Machine, - machines Machines, - filter []tailcfg.FilterRule, -) Machines { - result := Machines{} - - for index, peer := range machines { - if peer.ID == machine.ID { - continue - } - - if machine.canAccess(filter, &machines[index]) || peer.canAccess(filter, machine) { - result = append(result, peer) - } - } - - return result -} - -func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). Msg("Finding direct peers") - machines := Machines{} + machines := types.Machines{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", machine.NodeKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") - return Machines{}, err + return types.Machines{}, err } sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) @@ -242,22 +73,21 @@ func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { } func (hsdb *HSDatabase) getPeers( - aclPolicy *ACLPolicy, aclRules []tailcfg.FilterRule, - machine *Machine, -) (Machines, error) { - var peers Machines + machine *types.Machine, +) (types.Machines, error) { + var peers types.Machines var err error // If ACLs rules are defined, filter visible host list with the ACLs // else use the classic user scope - if aclPolicy != nil { - var machines []Machine + if len(aclRules) > 0 { + var machines []types.Machine machines, err = hsdb.ListMachines() if err != nil { log.Error().Err(err).Msg("Error retrieving list of machines") - return Machines{}, err + return types.Machines{}, err } peers = hsdb.filterMachinesByACL(aclRules, machine, machines) } else { @@ -268,7 +98,7 @@ func (hsdb *HSDatabase) getPeers( Err(err). Msg("Cannot fetch peers") - return Machines{}, err + return types.Machines{}, err } } @@ -283,20 +113,19 @@ func (hsdb *HSDatabase) getPeers( return peers, nil } -func (hsdb *HSDatabase) getValidPeers( - aclPolicy *ACLPolicy, +func (hsdb *HSDatabase) GetValidPeers( aclRules []tailcfg.FilterRule, - machine *Machine, -) (Machines, error) { - validPeers := make(Machines, 0) + machine *types.Machine, +) (types.Machines, error) { + validPeers := make(types.Machines, 0) - peers, err := hsdb.getPeers(aclPolicy, aclRules, machine) + peers, err := hsdb.getPeers(aclRules, machine) if err != nil { - return Machines{}, err + return types.Machines{}, err } for _, peer := range peers { - if !peer.isExpired() { + if !peer.IsExpired() { validPeers = append(validPeers, peer) } } @@ -304,8 +133,8 @@ func (hsdb *HSDatabase) getValidPeers( return validPeers, nil } -func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { - machines := []Machine{} +func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { + machines := []types.Machine{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { return nil, err } @@ -313,8 +142,8 @@ func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { return machines, nil } -func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) { - machines := []Machine{} +func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { + machines := types.Machines{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { return nil, err } @@ -323,7 +152,7 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, er } // GetMachine finds a Machine by name and user and returns the Machine struct. -func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { +func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -339,7 +168,10 @@ func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { } // GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct. -func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) { +func (hsdb *HSDatabase) GetMachineByGivenName( + user string, + givenName string, +) (*types.Machine, error) { machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -355,9 +187,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*M } // GetMachineByID finds a Machine by ID and returns the Machine struct. -func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { - m := Machine{} - if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { +func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { + m := types.Machine{} + if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&types.Machine{ID: id}).First(&m); result.Error != nil { return nil, result.Error } @@ -367,8 +199,8 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, -) (*Machine, error) { - m := Machine{} +) (*types.Machine, error) { + m := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } @@ -379,8 +211,8 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( // GetMachineByNodeKey finds a Machine by its current NodeKey. func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, -) (*Machine, error) { - machine := Machine{} +) (*types.Machine, error) { + machine := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { return nil, result.Error @@ -392,8 +224,8 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( // GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, -) (*Machine, error) { - machine := Machine{} +) (*types.Machine, error) { + machine := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", util.MachinePublicKeyStripPrefix(machineKey), util.NodePublicKeyStripPrefix(nodeKey), @@ -404,9 +236,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( return &machine, nil } +// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { +func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -416,13 +249,9 @@ func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { // SetTags takes a Machine struct pointer and update the forced tags. func (hsdb *HSDatabase) SetTags( - machine *Machine, + machine *types.Machine, tags []string, - // TODO(kradalby): This is a temporary measure to be able to detach the - // database completely from the global h. In the future, as part of this - // reorg, the rules will be generated on a per node basis, and not be prone - // to throwing error at save. - updateACL func() error) error { +) error { newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { @@ -430,10 +259,8 @@ func (hsdb *HSDatabase) SetTags( } } machine.ForcedTags = newTags - if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) { - return err - } + hsdb.notifyPolicyChan <- struct{}{} hsdb.notifyStateChange() if err := hsdb.db.Save(machine).Error; err != nil { @@ -444,7 +271,7 @@ func (hsdb *HSDatabase) SetTags( } // ExpireMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { +func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { now := time.Now() machine.Expiry = &now @@ -459,8 +286,8 @@ func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. -func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { - err := CheckForFQDNRules( +func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { + err := util.CheckForFQDNRules( newName, ) if err != nil { @@ -484,8 +311,8 @@ func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { return nil } -// RefreshMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error { +// RefreshMachine takes a Machine struct and a new expiry time. +func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { now := time.Now() machine.LastSuccessfulUpdate = &now @@ -504,7 +331,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error } // DeleteMachine softs deletes a Machine from the database. -func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { +func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -517,8 +344,8 @@ func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { return nil } -func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { - return hsdb.db.Updates(Machine{ +func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { + return hsdb.db.Updates(types.Machine{ ID: machine.ID, LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, @@ -526,7 +353,7 @@ func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { } // HardDeleteMachine hard deletes a Machine from the database. -func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { +func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -539,12 +366,7 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { return nil } -// GetHostInfo returns a Hostinfo struct for the machine. -func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { - return tailcfg.Hostinfo(machine.HostInfo) -} - -func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool { +func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool { if err := hsdb.UpdateMachineFromDatabase(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. @@ -570,291 +392,13 @@ func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool return lastUpdate.Before(lastChange) } -func (machine Machine) String() string { - return machine.Hostname -} - -func (machines Machines) String() string { - temp := make([]string, len(machines)) - - for index, machine := range machines { - temp[index] = machine.Hostname - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -// TODO(kradalby): Remove when we have generics... -func (machines MachinesP) String() string { - temp := make([]string, len(machines)) - - for index, machine := range machines { - temp[index] = machine.Hostname - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -func (hsdb *HSDatabase) toNodes( - machines Machines, - aclPolicy *ACLPolicy, - baseDomain string, - dnsConfig *tailcfg.DNSConfig, -) ([]*tailcfg.Node, error) { - nodes := make([]*tailcfg.Node, len(machines)) - - for index, machine := range machines { - node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig) - if err != nil { - return nil, err - } - - nodes[index] = node - } - - return nodes, nil -} - -// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes -// as per the expected behaviour in the official SaaS. -func (hsdb *HSDatabase) toNode( - machine Machine, - aclPolicy *ACLPolicy, - baseDomain string, - dnsConfig *tailcfg.DNSConfig, -) (*tailcfg.Node, error) { - var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) - if err != nil { - log.Trace(). - Caller(). - Str("node_key", machine.NodeKey). - Msgf("Failed to parse node public key from hex") - - return nil, fmt.Errorf("failed to parse node public key: %w", err) - } - - var machineKey key.MachinePublic - // MachineKey is only used in the legacy protocol - if machine.MachineKey != "" { - err = machineKey.UnmarshalText( - []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse machine public key: %w", err) - } - } - - var discoKey key.DiscoPublic - if machine.DiscoKey != "" { - err := discoKey.UnmarshalText( - []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse disco public key: %w", err) - } - } else { - discoKey = key.DiscoPublic{} - } - - addrs := []netip.Prefix{} - for _, machineAddress := range machine.IPAddresses { - ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen()) - addrs = append(addrs, ip) - } - - allowedIPs := append( - []netip.Prefix{}, - addrs...) // we append the node own IP, as it is required by the clients - - primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine) - if err != nil { - return nil, err - } - primaryPrefixes := Routes(primaryRoutes).toPrefixes() - - machineRoutes, err := hsdb.GetMachineRoutes(&machine) - if err != nil { - return nil, err - } - for _, route := range machineRoutes { - if route.Enabled && (route.IsPrimary || route.isExitRoute()) { - allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) - } - } - - var derp string - if machine.HostInfo.NetInfo != nil { - derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) - } else { - derp = "127.3.3.40:0" // Zero means disconnected or unknown. - } - - var keyExpiry time.Time - if machine.Expiry != nil { - keyExpiry = *machine.Expiry - } else { - keyExpiry = time.Time{} - } - - var hostname string - if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - hostname = fmt.Sprintf( - "%s.%s.%s", - machine.GivenName, - machine.User.Name, - baseDomain, - ) - if len(hostname) > maxHostnameLength { - return nil, fmt.Errorf( - "hostname %q is too long it cannot except 255 ASCII chars: %w", - hostname, - ErrHostnameTooLong, - ) - } - } else { - hostname = machine.GivenName - } - - hostInfo := machine.GetHostInfo() - - online := machine.isOnline() - - tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain) - tags = lo.Uniq(append(tags, machine.ForcedTags...)) - - node := tailcfg.Node{ - ID: tailcfg.NodeID(machine.ID), // this is the actual ID - StableID: tailcfg.StableNodeID( - strconv.FormatUint(machine.ID, util.Base10), - ), // in headscale, unlike tailcontrol server, IDs are permanent - Name: hostname, - - User: tailcfg.UserID(machine.UserID), - - Key: nodeKey, - KeyExpiry: keyExpiry, - - Machine: machineKey, - DiscoKey: discoKey, - Addresses: addrs, - AllowedIPs: allowedIPs, - Endpoints: machine.Endpoints, - DERP: derp, - Hostinfo: hostInfo.View(), - Created: machine.CreatedAt, - - Tags: tags, - - PrimaryRoutes: primaryPrefixes, - - LastSeen: machine.LastSeen, - Online: &online, - KeepAlive: true, - MachineAuthorized: !machine.isExpired(), - - Capabilities: []string{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - }, - } - - return &node, nil -} - -func (machine *Machine) toProto() *v1.Machine { - machineProto := &v1.Machine{ - Id: machine.ID, - MachineKey: machine.MachineKey, - - NodeKey: machine.NodeKey, - DiscoKey: machine.DiscoKey, - IpAddresses: machine.IPAddresses.ToStringSlice(), - Name: machine.Hostname, - GivenName: machine.GivenName, - User: machine.User.toProto(), - ForcedTags: machine.ForcedTags, - Online: machine.isOnline(), - - // TODO(kradalby): Implement register method enum converter - // RegisterMethod: , - - CreatedAt: timestamppb.New(machine.CreatedAt), - } - - if machine.AuthKey != nil { - machineProto.PreAuthKey = machine.AuthKey.toProto() - } - - if machine.LastSeen != nil { - machineProto.LastSeen = timestamppb.New(*machine.LastSeen) - } - - if machine.LastSuccessfulUpdate != nil { - machineProto.LastSuccessfulUpdate = timestamppb.New( - *machine.LastSuccessfulUpdate, - ) - } - - if machine.Expiry != nil { - machineProto.Expiry = timestamppb.New(*machine.Expiry) - } - - return machineProto -} - -// getTags will return the tags of the current machine. -// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. -// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. -func getTags( - aclPolicy *ACLPolicy, - machine Machine, - stripEmailDomain bool, -) ([]string, []string) { - validTags := make([]string, 0) - invalidTags := make([]string, 0) - if aclPolicy == nil { - return validTags, invalidTags - } - validTagMap := make(map[string]bool) - invalidTagMap := make(map[string]bool) - for _, tag := range machine.HostInfo.RequestTags { - owners, err := getTagOwners(aclPolicy, tag, stripEmailDomain) - if errors.Is(err, errInvalidTag) { - invalidTagMap[tag] = true - - continue - } - var found bool - for _, owner := range owners { - if machine.User.Name == owner { - found = true - } - } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true - } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) - } - - return validTags, invalidTags -} - func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( cache *cache.Cache, nodeKeyStr string, userName string, machineExpiry *time.Time, registrationMethod string, -) (*Machine, error) { +) (*types.Machine, error) { nodeKey := key.NodePublic{} err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) if err != nil { @@ -869,7 +413,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( Msg("Registering machine from API/CLI or auth callback") if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { - if registrationMachine, ok := machineInterface.(Machine); ok { + if registrationMachine, ok := machineInterface.(types.Machine); ok { user, err := hsdb.GetUser(userName) if err != nil { return nil, fmt.Errorf( @@ -909,8 +453,8 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (hsdb *HSDatabase) RegisterMachine(machine Machine, -) (*Machine, error) { +func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, +) (*types.Machine, error) { log.Debug(). Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). @@ -966,9 +510,44 @@ func (hsdb *HSDatabase) RegisterMachine(machine Machine, return &machine, nil } +// MachineSetNodeKey sets the node key of a machine and saves it to the database. +func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { + machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) + + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + +// MachineSetMachineKey sets the machine key of a machine and saves it to the database. +func (hsdb *HSDatabase) MachineSetMachineKey( + machine *types.Machine, + nodeKey key.MachinePublic, +) error { + machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) + + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + +// MachineSave saves a machine object to the database, prefer to use a specific save method rather +// than this. It is intended to be used when we are changing or. +func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. -func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { - routes := []Route{} +func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { + routes := types.Routes{} err := hsdb.db. Preload("Machine"). @@ -992,8 +571,8 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, e } // GetEnabledRoutes returns the routes that are enabled for the machine. -func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { - routes := []Route{} +func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { + routes := types.Routes{} err := hsdb.db. Preload("Machine"). @@ -1017,7 +596,7 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, erro return prefixes, nil } -func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool { +func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false @@ -1040,7 +619,7 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool } // enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error { +func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -1068,16 +647,16 @@ func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) erro // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { - route := Route{} + route := types.Route{} err := hsdb.db.Preload("Machine"). - Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). + Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). First(&route).Error if err == nil { route.Enabled = true // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) - if !route.isExitRoute() { + if !route.IsExitRoute() { route.IsPrimary = hsdb.isUniquePrefix(route) } @@ -1095,81 +674,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) erro return nil } -// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. -func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error { - if len(machine.IPAddresses) == 0 { - return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs - } - - routes := []Route{} - err := hsdb.db. - Preload("Machine"). - Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID). - Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("machine", machine.Hostname). - Msg("Could not get advertised routes for machine") - - return err - } - - approvedRoutes := []Route{} - - for _, advertisedRoute := range routes { - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - log.Err(err). - Str("advertisedRoute", advertisedRoute.String()). - Uint64("machineId", machine.ID). - Msg("Failed to resolve autoApprovers for advertised route") - - return err - } - - for _, approvedAlias := range routeApprovers { - if approvedAlias == machine.User.Name { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } else { - approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain) - if err != nil { - log.Err(err). - Str("alias", approvedAlias). - Msg("Failed to expand alias when processing autoApprovers policy") - - return err - } - - // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first - if approvedIps.Contains(machine.IPAddresses[0]) { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } - } - } - } - - for i, approvedRoute := range approvedRoutes { - approvedRoutes[i].Enabled = true - err = hsdb.db.Save(&approvedRoutes[i]).Error - if err != nil { - log.Err(err). - Str("approvedRoute", approvedRoute.String()). - Uint64("machineId", machine.ID). - Msg("Failed to enable approved route") - - return err - } - } - - return nil -} - func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { - normalizedHostname, err := NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRules( suppliedName, hsdb.stripEmailDomain, ) @@ -1179,7 +685,7 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool if randomSuffix { // Trim if a hostname will be longer than 63 chars after adding the hash. - trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize + trimmedHostnameLength := util.LabelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize if len(normalizedHostname) > trimmedHostnameLength { normalizedHostname = normalizedHostname[:trimmedHostnameLength] } @@ -1221,16 +727,260 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string return givenName, nil } -func (machines Machines) FilterByIP(ip netip.Addr) Machines { - found := make(Machines, 0) +func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { + users, err := hsdb.ListUsers() + if err != nil { + log.Error().Err(err).Msg("Error listing users") - for _, machine := range machines { - for _, mIP := range machine.IPAddresses { - if ip == mIP { - found = append(found, machine) + return + } + + for _, user := range users { + machines, err := hsdb.ListMachinesByUser(user.Name) + if err != nil { + log.Error(). + Err(err). + Str("user", user.Name). + Msg("Error listing machines in user") + + return + } + + expiredFound := false + for idx, machine := range machines { + if machine.IsEphemeral() && machine.LastSeen != nil && + time.Now(). + After(machine.LastSeen.Add(inactivityThreshhold)) { + expiredFound = true + log.Info(). + Str("machine", machine.Hostname). + Msg("Ephemeral client removed from database") + + err = hsdb.HardDeleteMachine(&machines[idx]) + if err != nil { + log.Error(). + Err(err). + Str("machine", machine.Hostname). + Msg("🤮 Cannot delete ephemeral machine from the database") + } } } + + if expiredFound { + hsdb.notifyStateChange() + } + } +} + +func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) { + users, err := hsdb.ListUsers() + if err != nil { + log.Error().Err(err).Msg("Error listing users") + + return + } + + for _, user := range users { + machines, err := hsdb.ListMachinesByUser(user.Name) + if err != nil { + log.Error(). + Err(err). + Str("user", user.Name). + Msg("Error listing machines in user") + + return + } + + expiredFound := false + for index, machine := range machines { + if machine.IsExpired() && + machine.Expiry.After(lastChange) { + expiredFound = true + + err := hsdb.ExpireMachine(&machines[index]) + if err != nil { + log.Error(). + Err(err). + Str("machine", machine.Hostname). + Str("name", machine.GivenName). + Msg("🤮 Cannot expire machine") + } else { + log.Info(). + Str("machine", machine.Hostname). + Str("name", machine.GivenName). + Msg("Machine successfully expired") + } + } + } + + if expiredFound { + hsdb.notifyStateChange() + } + } +} + +func (hsdb *HSDatabase) TailNodes( + machines types.Machines, + pol *policy.ACLPolicy, + dnsConfig *tailcfg.DNSConfig, +) ([]*tailcfg.Node, error) { + nodes := make([]*tailcfg.Node, len(machines)) + + for index, machine := range machines { + node, err := hsdb.TailNode(machine, pol, dnsConfig) + if err != nil { + return nil, err + } + + nodes[index] = node + } + + return nodes, nil +} + +// TailNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes +// as per the expected behaviour in the official SaaS. +func (hsdb *HSDatabase) TailNode( + machine types.Machine, + pol *policy.ACLPolicy, + dnsConfig *tailcfg.DNSConfig, +) (*tailcfg.Node, error) { + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) + if err != nil { + log.Trace(). + Caller(). + Str("node_key", machine.NodeKey). + Msgf("Failed to parse node public key from hex") + + return nil, fmt.Errorf("failed to parse node public key: %w", err) + } + + var machineKey key.MachinePublic + // MachineKey is only used in the legacy protocol + if machine.MachineKey != "" { + err = machineKey.UnmarshalText( + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse machine public key: %w", err) + } } - return found + var discoKey key.DiscoPublic + if machine.DiscoKey != "" { + err := discoKey.UnmarshalText( + []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse disco public key: %w", err) + } + } else { + discoKey = key.DiscoPublic{} + } + + addrs := []netip.Prefix{} + for _, machineAddress := range machine.IPAddresses { + ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen()) + addrs = append(addrs, ip) + } + + allowedIPs := append( + []netip.Prefix{}, + addrs...) // we append the node own IP, as it is required by the clients + + primaryRoutes, err := hsdb.GetMachinePrimaryRoutes(&machine) + if err != nil { + return nil, err + } + primaryPrefixes := primaryRoutes.Prefixes() + + machineRoutes, err := hsdb.GetMachineRoutes(&machine) + if err != nil { + return nil, err + } + for _, route := range machineRoutes { + if route.Enabled && (route.IsPrimary || route.IsExitRoute()) { + allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) + } + } + + var derp string + if machine.HostInfo.NetInfo != nil { + derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) + } else { + derp = "127.3.3.40:0" // Zero means disconnected or unknown. + } + + var keyExpiry time.Time + if machine.Expiry != nil { + keyExpiry = *machine.Expiry + } else { + keyExpiry = time.Time{} + } + + var hostname string + if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS + hostname = fmt.Sprintf( + "%s.%s.%s", + machine.GivenName, + machine.User.Name, + hsdb.baseDomain, + ) + if len(hostname) > MaxHostnameLength { + return nil, fmt.Errorf( + "hostname %q is too long it cannot except 255 ASCII chars: %w", + hostname, + ErrHostnameTooLong, + ) + } + } else { + hostname = machine.GivenName + } + + hostInfo := machine.GetHostInfo() + + online := machine.IsOnline() + + tags, _ := pol.GetTagsOfMachine(machine, hsdb.stripEmailDomain) + tags = lo.Uniq(append(tags, machine.ForcedTags...)) + + node := tailcfg.Node{ + ID: tailcfg.NodeID(machine.ID), // this is the actual ID + StableID: tailcfg.StableNodeID( + strconv.FormatUint(machine.ID, util.Base10), + ), // in headscale, unlike tailcontrol server, IDs are permanent + Name: hostname, + + User: tailcfg.UserID(machine.UserID), + + Key: nodeKey, + KeyExpiry: keyExpiry, + + Machine: machineKey, + DiscoKey: discoKey, + Addresses: addrs, + AllowedIPs: allowedIPs, + Endpoints: machine.Endpoints, + DERP: derp, + Hostinfo: hostInfo.View(), + Created: machine.CreatedAt, + + Tags: tags, + + PrimaryRoutes: primaryPrefixes, + + LastSeen: machine.LastSeen, + Online: &online, + KeepAlive: true, + MachineAuthorized: !machine.IsExpired(), + + Capabilities: []string{ + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityAdmin, + tailcfg.CapabilitySSH, + }, + } + + return &node, nil } diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go new file mode 100644 index 0000000..f34f64d --- /dev/null +++ b/hscontrol/db/machine_test.go @@ -0,0 +1,797 @@ +package db + +import ( + "fmt" + "net/netip" + "regexp" + "strconv" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func (s *Suite) TestGetMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByID(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByNodeKey(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + machine := types.Machine{ + ID: 0, + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByNodeKey(nodeKey.Public()) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + oldNodeKey := key.NewNode() + + machineKey := key.NewMachine() + + machine := types.Machine{ + ID: 0, + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestDeleteMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(1), + } + db.db.Save(&machine) + + err = db.DeleteMachine(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(user.Name, "testmachine") + c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestHardDeleteMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine3", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(1), + } + db.db.Save(&machine) + + err = db.HardDeleteMachine(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(user.Name, "testmachine3") + c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestListPeers(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + machine := types.Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + Hostname: "testmachine" + strconv.Itoa(index), + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + } + + machine0ByID, err := db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + peersOfMachine0, err := db.ListPeers(machine0ByID) + c.Assert(err, check.IsNil) + + c.Assert(len(peersOfMachine0), check.Equals, 9) + c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") + c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") + c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetACLFilteredPeers(c *check.C) { + type base struct { + user *types.User + key *types.PreAuthKey + } + + stor := make([]base, 0) + + for _, name := range []string{"test", "admin"} { + user, err := db.CreateUser(name) + c.Assert(err, check.IsNil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + stor = append(stor, base{user, pak}) + } + + _, err := db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + machine := types.Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), + }, + Hostname: "testmachine" + strconv.Itoa(index), + UserID: stor[index%2].user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(stor[index%2].key.ID), + } + db.db.Save(&machine) + } + + aclPolicy := &policy.ACLPolicy{ + Groups: map[string][]string{ + "group:test": {"admin"}, + }, + Hosts: map[string]netip.Prefix{}, + TagOwners: map[string][]string{}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"admin"}, + Destinations: []string{"*:*"}, + }, + { + Action: "accept", + Sources: []string{"test"}, + Destinations: []string{"test:*"}, + }, + }, + Tests: []policy.ACLTest{}, + } + + adminMachine, err := db.GetMachineByID(1) + c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) + c.Assert(err, check.IsNil) + + testMachine, err := db.GetMachineByID(2) + c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) + c.Assert(err, check.IsNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) + c.Assert(err, check.IsNil) + + peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) + peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) + + c.Log(peersOfTestMachine) + c.Assert(len(peersOfTestMachine), check.Equals, 9) + c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") + c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") + c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") + + c.Log(peersOfAdminMachine) + c.Assert(len(peersOfAdminMachine), check.Equals, 9) + c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") + c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") + c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestExpireMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + Expiry: &time.Time{}, + } + db.db.Save(machine) + + machineFromDB, err := db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machineFromDB, check.NotNil) + + c.Assert(machineFromDB.IsExpired(), check.Equals, false) + + err = db.ExpireMachine(machineFromDB) + c.Assert(err, check.IsNil) + + c.Assert(machineFromDB.IsExpired(), check.Equals, true) + + c.Assert(channelUpdates, check.Equals, int32(1)) +} + +func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { + input := types.MachineAddresses([]netip.Addr{ + netip.MustParseAddr("192.0.2.1"), + netip.MustParseAddr("2001:db8::1"), + }) + serialized, err := input.Value() + c.Assert(err, check.IsNil) + if serial, ok := serialized.(string); ok { + c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") + } + + var deserialized types.MachineAddresses + err = deserialized.Scan(serialized) + c.Assert(err, check.IsNil) + + c.Assert(len(deserialized), check.Equals, len(input)) + for i := range deserialized { + c.Assert(deserialized[i], check.Equals, input[i]) + } + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGenerateGivenName(c *check.C) { + user1, err := db.CreateUser("user-1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user-1", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "machine-key-1", + NodeKey: "node-key-1", + DiscoKey: "disco-key-1", + Hostname: "hostname-1", + GivenName: "hostname-1", + UserID: user1.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2") + comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-2", comment) + + givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1") + comment = check.Commentf("Same user, same machine, same hostname, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-1", comment) + + givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") + comment = check.Commentf("Same user, unique machines, same hostname, conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) + + givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") + comment = check.Commentf("Unique users, unique machines, same hostname, conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestSetTags(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + // assign simple tags + sTags := []string{"tag:test", "tag:foo"} + err = db.SetTags(machine, sTags) + c.Assert(err, check.IsNil) + machine, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags)) + + // assign duplicat tags, expect no errors but no doubles in DB + eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} + err = db.SetTags(machine, eTags) + c.Assert(err, check.IsNil) + machine, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert( + machine.ForcedTags, + check.DeepEquals, + types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + ) + + c.Assert(channelUpdates, check.Equals, int32(4)) +} + +func TestHeadscale_generateGivenName(t *testing.T) { + type args struct { + suppliedName string + randomSuffix bool + } + tests := []struct { + name string + db *HSDatabase + args args + want *regexp.Regexp + wantErr bool + }{ + { + name: "simple machine name generation", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "testmachine", + randomSuffix: false, + }, + want: regexp.MustCompile("^testmachine$"), + wantErr: false, + }, + { + name: "machine name with 53 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + randomSuffix: false, + }, + want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), + wantErr: false, + }, + { + name: "machine name with 63 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), + wantErr: false, + }, + { + name: "machine name with 64 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "machine name with 73 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "machine name with random suffix", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "test", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), + wantErr: false, + }, + { + name: "machine name with 63 chars with random suffix", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + if (err != nil) != tt.wantErr { + t.Errorf( + "Headscale.GenerateGivenName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + + if tt.want != nil && !tt.want.MatchString(got) { + t.Errorf( + "Headscale.GenerateGivenName() = %v, does not match %v", + tt.want, + got, + ) + } + + if len(got) > util.LabelHostnameLength { + t.Errorf( + "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", + got, + util.LabelHostnameLength, + ) + } + }) + } +} + +func (s *Suite) TestAutoApproveRoutes(c *check.C) { + acl := []byte(` +{ + "tagOwners": { + "tag:exit": ["test"], + }, + + "groups": { + "group:test": ["test"] + }, + + "acls": [ + {"action": "accept", "users": ["*"], "ports": ["*:*"]}, + ], + + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test"], + } + } +} + `) + + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) + + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + nodeKey := key.NewNode() + + defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") + defaultRouteV6 := netip.MustParsePrefix("::/0") + route1 := netip.MustParsePrefix("10.10.0.0/16") + // Check if a subprefix of an autoapproved route is approved + route2 := netip.MustParsePrefix("10.11.0.0/24") + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:exit"}, + RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, + }, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + } + + db.db.Save(&machine) + + err = db.ProcessMachineRoutes(&machine) + c.Assert(err, check.IsNil) + + machine0ByID, err := db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + err = db.EnableAutoApprovedRoutes(pol, machine0ByID) + c.Assert(err, check.IsNil) + + enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) + c.Assert(err, check.IsNil) + c.Assert(enabledRoutes, check.HasLen, 4) + + c.Assert(channelUpdates, check.Equals, int32(4)) +} + +func TestMachine_canAccess(t *testing.T) { + type args struct { + filter []tailcfg.FilterRule + machine2 *types.Machine + } + tests := []struct { + name string + machine types.Machine + args args + want bool + }{ + { + name: "no-rules", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{}, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: false, + }, + { + name: "wildcard", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: true, + }, + { + name: "explicit-m1-to-m2", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"10.0.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.0.0.2", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: true, + }, + { + name: "explicit-m2-to-m1", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"10.0.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.0.0.1", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want { + t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/hscontrol/preauth_keys.go b/hscontrol/db/preauth_keys.go similarity index 59% rename from hscontrol/preauth_keys.go rename to hscontrol/db/preauth_keys.go index 1956762..abb79c3 100644 --- a/hscontrol/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -1,17 +1,14 @@ -package hscontrol +package db import ( "crypto/rand" "encoding/hex" "errors" "fmt" - "strconv" "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/util" - "google.golang.org/protobuf/types/known/timestamppb" + "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" ) @@ -23,28 +20,6 @@ var ( ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) -// PreAuthKey describes a pre-authorization key usable in a particular user. -type PreAuthKey struct { - ID uint64 `gorm:"primary_key"` - Key string - UserID uint - User User - Reusable bool - Ephemeral bool `gorm:"default:false"` - Used bool `gorm:"default:false"` - ACLTags []PreAuthKeyACLTag - - CreatedAt *time.Time - Expiration *time.Time -} - -// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. -type PreAuthKeyACLTag struct { - ID uint64 `gorm:"primary_key"` - PreAuthKeyID uint64 - Tag string -} - // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( userName string, @@ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( ephemeral bool, expiration *time.Time, aclTags []string, -) (*PreAuthKey, error) { +) (*types.PreAuthKey, error) { user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( return nil, err } - key := PreAuthKey{ + key := types.PreAuthKey{ Key: kstr, UserID: user.ID, User: *user, @@ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( for _, tag := range aclTags { if !seenTags[tag] { - if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { return fmt.Errorf( "failed to ceate key tag in the database: %w", err, @@ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( } // ListPreAuthKeys returns the list of PreAuthKeys for a user. -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { user, err := hsdb.GetUser(userName) if err != nil { return nil, err } - keys := []PreAuthKey{} - if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + keys := []types.PreAuthKey{} + if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { } // GetPreAuthKey returns a PreAuthKey for a given key. -func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { - pak, err := hsdb.checkKeyValidity(key) +func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { + pak, err := hsdb.ValidatePreAuthKey(key) if err != nil { return nil, err } @@ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { +func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { return hsdb.db.Transaction(func(db *gorm.DB) error { - if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { + if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { } // MarkExpirePreAuthKey marks a PreAuthKey as expired. -func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { k.Used = true if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) @@ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { return nil } -// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node +// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { - pak := PreAuthKey{} +func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + pak := types.PreAuthKey{} if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, @@ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { return &pak, nil } - machines := []Machine{} - if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + machines := types.Machines{} + if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { return nil, err } @@ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) { return hex.EncodeToString(bytes), nil } - -func (key *PreAuthKey) toProto() *v1.PreAuthKey { - protoKey := v1.PreAuthKey{ - User: key.User.Name, - Id: strconv.FormatUint(key.ID, util.Base10), - Key: key.Key, - Ephemeral: key.Ephemeral, - Reusable: key.Reusable, - Used: key.Used, - AclTags: make([]string, len(key.ACLTags)), - } - - if key.Expiration != nil { - protoKey.Expiration = timestamppb.New(*key.Expiration) - } - - if key.CreatedAt != nil { - protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) - } - - for idx := range key.ACLTags { - protoKey.AclTags[idx] = key.ACLTags[idx].Tag - } - - return &protoKey -} diff --git a/hscontrol/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go similarity index 57% rename from hscontrol/preauth_keys_test.go rename to hscontrol/db/preauth_keys_test.go index a85a6c6..e4a9773 100644 --- a/hscontrol/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -1,20 +1,22 @@ -package hscontrol +package db import ( "time" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) + _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) c.Assert(err, check.NotNil) - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { // Make sure the User association is populated c.Assert(key.User.Name, check.Equals, user.Name) - _, err = app.db.ListPreAuthKeys("bogus") + _, err = db.ListPreAuthKeys("bogus") c.Assert(err, check.NotNil) - keys, err := app.db.ListPreAuthKeys(user.Name) + keys, err := db.ListPreAuthKeys(user.Name) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) @@ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { } func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := app.db.CreateUser("test2") + user, err := db.CreateUser("test2") c.Assert(err, check.IsNil) now := time.Now() - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := app.db.checkKeyValidity("potatoKey") + key, err := db.ValidatePreAuthKey("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) c.Assert(key, check.IsNil) } func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := app.db.CreateUser("test3") + user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := app.db.CreateUser("test4") + user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(key, check.IsNil) } func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := app.db.CreateUser("test5") + user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - machine := Machine{ + machine := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := app.db.CreateUser("test6") + user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestEphemeralKey(c *check.C) { - user, err := app.db.CreateUser("test7") + user, err := db.CreateUser("test7") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) c.Assert(err, check.IsNil) - now := time.Now() - machine := Machine{ + now := time.Now().Add(-time.Second * 30) + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, LastSeen: &now, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - _, err = app.db.checkKeyValidity(pak.Key) + _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test7", "testest") + _, err = db.GetMachine("test7", "testest") c.Assert(err, check.IsNil) - app.expireEphemeralNodesWorker() + db.ExpireEphemeralMachines(time.Second * 20) // The machine record should have been deleted - _, err = app.db.GetMachine("test7", "testest") + _, err = db.GetMachine("test7", "testest") c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(1)) } func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := app.db.CreateUser("test3") + user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) - err = app.db.ExpirePreAuthKey(pak) + err = db.ExpirePreAuthKey(pak) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.NotNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := app.db.CreateUser("test6") + user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - app.db.db.Save(&pak) + db.db.Save(&pak) - _, err = app.db.checkKeyValidity(pak.Key) + _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) } func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := app.db.CreateUser("test8") + user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := app.db.ListPreAuthKeys("test8") + listedPaks, err := db.ListPreAuthKeys("test8") c.Assert(err, check.IsNil) - c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) + c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags) } diff --git a/hscontrol/routes.go b/hscontrol/db/routes.go similarity index 62% rename from hscontrol/routes.go rename to hscontrol/db/routes.go index e3be2f6..bdb3f4c 100644 --- a/hscontrol/routes.go +++ b/hscontrol/db/routes.go @@ -1,55 +1,19 @@ -package hscontrol +package db import ( "errors" - "fmt" "net/netip" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) -var ( - ErrRouteIsNotAvailable = errors.New("route is not available") - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") -) +var ErrRouteIsNotAvailable = errors.New("route is not available") -type Route struct { - gorm.Model - - MachineID uint64 - Machine Machine - Prefix IPPrefix - - Advertised bool - Enabled bool - IsPrimary bool -} - -type Routes []Route - -func (r *Route) String() string { - return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) -} - -func (r *Route) isExitRoute() bool { - return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 -} - -func (rs Routes) toPrefixes() []netip.Prefix { - prefixes := make([]netip.Prefix, len(rs)) - for i, r := range rs { - prefixes[i] = netip.Prefix(r.Prefix) - } - - return prefixes -} - -func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { - var routes []Route +func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { + var routes types.Routes err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { return nil, err @@ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { - var routes []Route +func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { + var routes types.Routes + err := hsdb.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = true", machine.ID). + Find(&routes).Error + if err != nil { + return nil, err + } + + return routes, nil +} + +func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("machine_id = ?", m.ID). @@ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { - var route Route +func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { + var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { return nil, err @@ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if route.isExitRoute() { - return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) + if route.IsExitRoute() { + return hsdb.enableRoutes( + &route.Machine, + types.ExitRouteV4.String(), + types.ExitRouteV6.String(), + ) } return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) @@ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.isExitRoute() { + if !route.IsExitRoute() { route.Enabled = false route.IsPrimary = false err = hsdb.db.Save(route).Error @@ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } routes, err := hsdb.GetMachineRoutes(&route.Machine) @@ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } for i := range routes { - if routes[i].isExitRoute() { + if routes[i].IsExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false err = hsdb.db.Save(&routes[i]).Error @@ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { @@ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.isExitRoute() { + if !route.IsExitRoute() { if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } routes, err := hsdb.GetMachineRoutes(&route.Machine) @@ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - routesToDelete := []Route{} + routesToDelete := types.Routes{} for _, r := range routes { - if r.isExitRoute() { + if r.IsExitRoute() { routesToDelete = append(routesToDelete, r) } } @@ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } -func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { +func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { routes, err := hsdb.GetMachineRoutes(m) if err != nil { return err @@ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { } } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. -func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { +func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { var count int64 hsdb.db. - Model(&Route{}). + Model(&types.Route{}). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, route.MachineID, @@ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { return count == 0 } -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { - var route Route +func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { + var route types.Route err := hsdb.db. Preload("Machine"). - Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). + Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). First(&route).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err @@ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { - var routes []Route +func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). @@ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { - currentRoutes := []Route{} +func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { + currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err @@ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { for prefix, exists := range advertisedRoutes { if !exists { - route := Route{ + route := types.Route{ MachineID: machine.ID, - Prefix: IPPrefix(prefix), + Prefix: types.IPPrefix(prefix), Advertised: true, Enabled: false, } @@ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { return nil } -func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { +func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { // first, get all the enabled routes - var routes []Route + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("advertised = ? AND enabled = ?", true, true). @@ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { routesChanged := false for pos, route := range routes { - if route.isExitRoute() { + if route.IsExitRoute() { continue } @@ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { } if route.IsPrimary { - if route.Machine.isOnline() { + if route.Machine.IsOnline() { continue } @@ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { Msgf("machine offline, finding a new primary subnet") // find a new primary route - var newPrimaryRoutes []Route + var newPrimaryRoutes types.Routes err := hsdb.db. Preload("Machine"). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", @@ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return err } - var newPrimaryRoute *Route + var newPrimaryRoute *types.Route for pos, r := range newPrimaryRoutes { - if r.Machine.isOnline() { + if r.Machine.IsOnline() { newPrimaryRoute = &newPrimaryRoutes[pos] break @@ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return nil } -func (rs Routes) toProto() []*v1.Route { - protoRoutes := []*v1.Route{} - - for _, route := range rs { - protoRoute := v1.Route{ - Id: uint64(route.ID), - Machine: route.Machine.toProto(), - Prefix: netip.Prefix(route.Prefix).String(), - Advertised: route.Advertised, - Enabled: route.Enabled, - IsPrimary: route.IsPrimary, - CreatedAt: timestamppb.New(route.CreatedAt), - UpdatedAt: timestamppb.New(route.UpdatedAt), - } - - if route.DeletedAt.Valid { - protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) - } - - protoRoutes = append(protoRoutes, &protoRoute) +// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. +func (hsdb *HSDatabase) EnableAutoApprovedRoutes( + aclPolicy *policy.ACLPolicy, + machine *types.Machine, +) error { + if len(machine.IPAddresses) == 0 { + return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - return protoRoutes + routes, err := hsdb.GetMachineAdvertisedRoutes(machine) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") + + return err + } + + approvedRoutes := types.Routes{} + + for _, advertisedRoute := range routes { + if advertisedRoute.Enabled { + continue + } + + routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( + netip.Prefix(advertisedRoute.Prefix), + ) + if err != nil { + log.Err(err). + Str("advertisedRoute", advertisedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to resolve autoApprovers for advertised route") + + return err + } + + for _, approvedAlias := range routeApprovers { + if approvedAlias == machine.User.Name { + approvedRoutes = append(approvedRoutes, advertisedRoute) + } else { + // TODO(kradalby): figure out how to get this to depend on less stuff + approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) + if err != nil { + log.Err(err). + Str("alias", approvedAlias). + Msg("Failed to expand alias when processing autoApprovers policy") + + return err + } + + // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first + if approvedIps.Contains(machine.IPAddresses[0]) { + approvedRoutes = append(approvedRoutes, advertisedRoute) + } + } + } + } + + for _, approvedRoute := range approvedRoutes { + err := hsdb.EnableRoute(uint64(approvedRoute.ID)) + if err != nil { + log.Err(err). + Str("approvedRoute", approvedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to enable approved route") + + return err + } + } + + return nil } diff --git a/hscontrol/routes_test.go b/hscontrol/db/routes_test.go similarity index 60% rename from hscontrol/routes_test.go rename to hscontrol/db/routes_test.go index cf437a4..d281452 100644 --- a/hscontrol/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -1,9 +1,11 @@ -package hscontrol +package db import ( "net/netip" "time" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" @@ -11,13 +13,13 @@ import ( ) func (s *Suite) TestGetRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_get_route_machine") + _, err = db.GetMachine("test", "test_get_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_get_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), + HostInfo: types.HostInfo(hostInfo), } - app.db.db.Save(&machine) + db.db.Save(&machine) - err = app.db.processMachineRoutes(&machine) + err = db.ProcessMachineRoutes(&machine) c.Assert(err, check.IsNil) - advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) + advertisedRoutes, err := db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = app.db.enableRoutes(&machine, "192.168.0.0/24") + err = db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetEnableRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), + HostInfo: types.HostInfo(hostInfo), } - app.db.db.Save(&machine) + db.db.Save(&machine) - err = app.db.processMachineRoutes(&machine) + err = db.ProcessMachineRoutes(&machine) c.Assert(err, check.IsNil) - availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) + availableRoutes, err := db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(err, check.IsNil) c.Assert(len(availableRoutes), check.Equals, 2) - noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) + noEnabledRoutes, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = app.db.enableRoutes(&machine, "192.168.0.0/24") + err = db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes, err := app.db.GetEnabledRoutes(&machine) + enabledRoutes, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) + enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = app.db.enableRoutes(&machine, "150.0.10.0/25") + err = db.enableRoutes(&machine, "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) + enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestIsUniquePrefix(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, route.String()) + err = db.enableRoutes(&machine1, route.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, route2.String()) + err = db.enableRoutes(&machine1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route2}, } - machine2 := Machine{ + machine2 := types.Machine{ ID: 2, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), + HostInfo: types.HostInfo(hostInfo2), } - app.db.db.Save(&machine2) + db.db.Save(&machine2) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, route2.String()) + err = db.enableRoutes(&machine2, route2.String()) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) + enabledRoutes2, err := db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.db.getMachinePrimaryRoutes(&machine1) + routes, err := db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) + + c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestSubnetFailover(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix2.String()) + err = db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - route, err := app.db.getPrimaryRoute(prefix) + route, err := db.getPrimaryRoute(prefix) c.Assert(err, check.IsNil) c.Assert(route.MachineID, check.Equals, machine1.ID) hostInfo2 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix2}, } - machine2 := Machine{ + machine2 := types.Machine{ ID: 2, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), + HostInfo: types.HostInfo(hostInfo2), LastSeen: &now, } - app.db.db.Save(&machine2) + db.db.Save(&machine2) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, prefix2.String()) + err = db.enableRoutes(&machine2, prefix2.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err = db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) + enabledRoutes2, err := db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.db.getMachinePrimaryRoutes(&machine1) + routes, err := db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) // lets make machine1 lastseen 10 mins ago before := now.Add(-10 * time.Minute) machine1.LastSeen = &before - err = app.db.db.Save(&machine1).Error + err = db.db.Save(&machine1).Error c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.db.getMachinePrimaryRoutes(&machine1) + routes, err = db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ + machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix, prefix2}, }) - err = app.db.db.Save(&machine2).Error + err = db.db.Save(&machine2).Error c.Assert(err, check.IsNil) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, prefix.String()) + err = db.enableRoutes(&machine2, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.db.getMachinePrimaryRoutes(&machine1) + routes, err = db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(6)) } // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, // including both the primary routes the node is responsible for, and the // exit node routes if enabled. func (s *Suite) TestAllowedIPRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { machineKey := key.NewMachine() now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) // We do not enable this one on purpose to test that it is not enabled - // err = app.db.enableRoutes(&machine1, prefix2.String()) + // err = db.enableRoutes(&machine1, prefix2.String()) // c.Assert(err, check.IsNil) - routes, err := app.db.GetMachineRoutes(&machine1) + routes, err := db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) for _, route := range routes { - if route.isExitRoute() { - err = app.db.EnableRoute(uint64(route.ID)) + if route.IsExitRoute() { + err = db.EnableRoute(uint64(route.ID)) c.Assert(err, check.IsNil) // We only enable one exit route, so we can test that both are enabled @@ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 3) - peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) + peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil) c.Assert(err, check.IsNil) c.Assert(len(peer.AllowedIPs), check.Equals, 3) @@ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { // Now we disable only one of the exit routes // and we see if both are disabled - var exitRouteV4 Route + var exitRouteV4 types.Route for _, route := range routes { - if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { + if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { exitRouteV4 = route break } } - err = app.db.DisableRoute(uint64(exitRouteV4.ID)) + err = db.DisableRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err = db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) // and now we delete only one of the exit routes // and we check if both are deleted - routes, err = app.db.GetMachineRoutes(&machine1) + routes, err = db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 4) - err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) + err = db.DeleteRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - routes, err = app.db.GetMachineRoutes(&machine1) + routes, err = db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(2)) } func (s *Suite) TestDeleteRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix2.String()) + err = db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - routes, err := app.db.GetMachineRoutes(&machine1) + routes, err := db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.DeleteRoute(uint64(routes[0].ID)) + err = db.DeleteRoute(uint64(routes[0].ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) + + c.Assert(channelUpdates, check.Equals, int32(2)) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go new file mode 100644 index 0000000..01541b9 --- /dev/null +++ b/hscontrol/db/suite_test.go @@ -0,0 +1,74 @@ +package db + +import ( + "net/netip" + "os" + "sync/atomic" + "testing" + + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +var ( + tmpDir string + db *HSDatabase + + // channelUpdates counts the number of times + // either of the channels was notified. + channelUpdates int32 +) + +func (s *Suite) SetUpTest(c *check.C) { + atomic.StoreInt32(&channelUpdates, 0) + s.ResetDB(c) +} + +func (s *Suite) TearDownTest(c *check.C) { + os.RemoveAll(tmpDir) +} + +func notificationSink(c <-chan struct{}) { + for { + <-c + atomic.AddInt32(&channelUpdates, 1) + } +} + +func (s *Suite) ResetDB(c *check.C) { + if len(tmpDir) != 0 { + os.RemoveAll(tmpDir) + } + var err error + tmpDir, err = os.MkdirTemp("", "autoygg-client-test") + if err != nil { + c.Fatal(err) + } + + sink := make(chan struct{}) + + go notificationSink(sink) + + db, err = NewHeadscaleDatabase( + "sqlite3", + tmpDir+"/headscale_test.db", + false, + false, + sink, + sink, + []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }, + "", + ) + if err != nil { + c.Fatal(err) + } +} diff --git a/hscontrol/users.go b/hscontrol/db/users.go similarity index 50% rename from hscontrol/users.go rename to hscontrol/db/users.go index fb3cea9..e0ffd19 100644 --- a/hscontrol/users.go +++ b/hscontrol/db/users.go @@ -1,17 +1,12 @@ -package hscontrol +package db import ( "errors" "fmt" - "regexp" - "strconv" - "strings" - "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -20,33 +15,16 @@ var ( ErrUserExists = errors.New("user already exists") ErrUserNotFound = errors.New("user not found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found") - ErrInvalidUserName = errors.New("invalid user name") ) -const ( - // value related to RFC 1123 and 952. - labelHostnameLength = 63 -) - -var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") - -// User is the way Headscale implements the concept of users in Tailscale -// -// At the end of the day, users in Tailscale are some kind of 'bubbles' or users -// that contain our machines. -type User struct { - gorm.Model - Name string `gorm:"unique"` -} - // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { - err := CheckForFQDNRules(name) +func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user := User{} + user := types.User{} if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } @@ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - err = CheckForFQDNRules(newName) + err = util.CheckForFQDNRules(newName) if err != nil { return err } @@ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { } // GetUser fetches a user by name. -func (hsdb *HSDatabase) GetUser(name string) (*User, error) { - user := User{} +func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { + user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, @@ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) { } // ListUsers gets all the existing users. -func (hsdb *HSDatabase) ListUsers() ([]User, error) { - users := []User{} +func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { + users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err } @@ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) { } // ListMachinesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { - err := CheckForFQDNRules(name) +func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { + err := util.CheckForFQDNRules(name) if err != nil { return nil, err } @@ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { return nil, err } - machines := []Machine{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { + machines := types.Machines{} + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil { return nil, err } @@ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { } // SetMachineUser assigns a Machine to a user. -func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { - err := CheckForFQDNRules(username) +func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { + err := util.CheckForFQDNRules(username) if err != nil { return err } @@ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error return nil } -func (n *User) toTailscaleUser() *tailcfg.User { - user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", - Domain: "headscale.net", - Logins: []tailcfg.LoginID{}, - Created: time.Time{}, - } - - return &user -} - -func (n *User) toTailscaleLogin() *tailcfg.Login { - login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", - Domain: "headscale.net", - } - - return &login -} - -func (hsdb *HSDatabase) getMapResponseUserProfiles( - machine Machine, - peers Machines, +func (hsdb *HSDatabase) GetMapResponseUserProfiles( + machine types.Machine, + peers types.Machines, ) []tailcfg.UserProfile { - userMap := make(map[string]User) + userMap := make(map[string]types.User) userMap[machine.User.Name] = machine.User for _, peer := range peers { userMap[peer.User.Name] = peer.User // not worth checking if already is there @@ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles( return profiles } - -func (n *User) toProto() *v1.User { - return &v1.User{ - Id: strconv.FormatUint(uint64(n.ID), util.Base10), - Name: n.Name, - CreatedAt: timestamppb.New(n.CreatedAt), - } -} - -// NormalizeToFQDNRules will replace forbidden chars in user -// it can also return an error if the user doesn't respect RFC 952 and 1123. -func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { - name = strings.ToLower(name) - name = strings.ReplaceAll(name, "'", "") - atIdx := strings.Index(name, "@") - if stripEmailDomain && atIdx > 0 { - name = name[:atIdx] - } else { - name = strings.ReplaceAll(name, "@", ".") - } - name = invalidCharsInUserRegex.ReplaceAllString(name, "-") - - for _, elt := range strings.Split(name, ".") { - if len(elt) > labelHostnameLength { - return "", fmt.Errorf( - "label %v is more than 63 chars: %w", - elt, - ErrInvalidUserName, - ) - } - } - - return name, nil -} - -func CheckForFQDNRules(name string) error { - if len(name) > labelHostnameLength { - return fmt.Errorf( - "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", - name, - ErrInvalidUserName, - ) - } - if strings.ToLower(name) != name { - return fmt.Errorf( - "DNS segment should be lowercase. %v doesn't comply with this rule: %w", - name, - ErrInvalidUserName, - ) - } - if invalidCharsInUserRegex.MatchString(name) { - return fmt.Errorf( - "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", - name, - ErrInvalidUserName, - ) - } - - return nil -} diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go new file mode 100644 index 0000000..02c0a2a --- /dev/null +++ b/hscontrol/db/users_test.go @@ -0,0 +1,277 @@ +package db + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func (s *Suite) TestCreateAndDestroyUser(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + c.Assert(user.Name, check.Equals, "test") + + users, err := db.ListUsers() + c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) + + err = db.DestroyUser("test") + c.Assert(err, check.IsNil) + + _, err = db.GetUser("test") + c.Assert(err, check.NotNil) +} + +func (s *Suite) TestDestroyUserErrors(c *check.C) { + err := db.DestroyUser("test") + c.Assert(err, check.Equals, ErrUserNotFound) + + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + err = db.DestroyUser("test") + c.Assert(err, check.IsNil) + + result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) + // destroying a user also deletes all associated preauthkeys + c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) + + user, err = db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + err = db.DestroyUser("test") + c.Assert(err, check.Equals, ErrUserStillHasNodes) +} + +func (s *Suite) TestRenameUser(c *check.C) { + userTest, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + c.Assert(userTest.Name, check.Equals, "test") + + users, err := db.ListUsers() + c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) + + err = db.RenameUser("test", "test-renamed") + c.Assert(err, check.IsNil) + + _, err = db.GetUser("test") + c.Assert(err, check.Equals, ErrUserNotFound) + + _, err = db.GetUser("test-renamed") + c.Assert(err, check.IsNil) + + err = db.RenameUser("test-does-not-exit", "test") + c.Assert(err, check.Equals, ErrUserNotFound) + + userTest2, err := db.CreateUser("test2") + c.Assert(err, check.IsNil) + c.Assert(userTest2.Name, check.Equals, "test2") + + err = db.RenameUser("test2", "test-renamed") + c.Assert(err, check.Equals, ErrUserExists) +} + +func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { + userShared1, err := db.CreateUser("shared1") + c.Assert(err, check.IsNil) + + userShared2, err := db.CreateUser("shared2") + c.Assert(err, check.IsNil) + + userShared3, err := db.CreateUser("shared3") + c.Assert(err, check.IsNil) + + preAuthKeyShared1, err := db.CreatePreAuthKey( + userShared1.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKeyShared2, err := db.CreatePreAuthKey( + userShared2.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKeyShared3, err := db.CreatePreAuthKey( + userShared3.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKey2Shared1, err := db.CreatePreAuthKey( + userShared1.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + c.Assert(err, check.NotNil) + + machineInShared1 := &types.Machine{ + ID: 1, + MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + Hostname: "test_get_shared_nodes_1", + UserID: userShared1.ID, + User: *userShared1, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + AuthKeyID: uint(preAuthKeyShared1.ID), + } + db.db.Save(machineInShared1) + + _, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname) + c.Assert(err, check.IsNil) + + machineInShared2 := &types.Machine{ + ID: 2, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_2", + UserID: userShared2.ID, + User: *userShared2, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, + AuthKeyID: uint(preAuthKeyShared2.ID), + } + db.db.Save(machineInShared2) + + _, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname) + c.Assert(err, check.IsNil) + + machineInShared3 := &types.Machine{ + ID: 3, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_3", + UserID: userShared3.ID, + User: *userShared3, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, + AuthKeyID: uint(preAuthKeyShared3.ID), + } + db.db.Save(machineInShared3) + + _, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname) + c.Assert(err, check.IsNil) + + machine2InShared1 := &types.Machine{ + ID: 4, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_4", + UserID: userShared1.ID, + User: *userShared1, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, + AuthKeyID: uint(preAuthKey2Shared1.ID), + } + db.db.Save(machine2InShared1) + + peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1) + c.Assert(err, check.IsNil) + + userProfiles := db.GetMapResponseUserProfiles( + *machineInShared1, + peersOfMachine1InShared1, + ) + + c.Assert(len(userProfiles), check.Equals, 3) + + found := false + for _, userProfiles := range userProfiles { + if userProfiles.DisplayName == userShared1.Name { + found = true + + break + } + } + c.Assert(found, check.Equals, true) + + found = false + for _, userProfile := range userProfiles { + if userProfile.DisplayName == userShared2.Name { + found = true + + break + } + } + c.Assert(found, check.Equals, true) +} + +func (s *Suite) TestSetMachineUser(c *check.C) { + oldUser, err := db.CreateUser("old") + c.Assert(err, check.IsNil) + + newUser, err := db.CreateUser("new") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: oldUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + c.Assert(machine.UserID, check.Equals, oldUser.ID) + + err = db.SetMachineUser(&machine, newUser.Name) + c.Assert(err, check.IsNil) + c.Assert(machine.UserID, check.Equals, newUser.ID) + c.Assert(machine.User.Name, check.Equals, newUser.Name) + + err = db.SetMachineUser(&machine, "non-existing-user") + c.Assert(err, check.Equals, ErrUserNotFound) + + err = db.SetMachineUser(&machine, newUser.Name) + c.Assert(err, check.IsNil) + c.Assert(machine.UserID, check.Equals, newUser.ID) + c.Assert(machine.User.Name, check.Equals, newUser.Name) +} diff --git a/hscontrol/dns.go b/hscontrol/dns.go index 72c5b03..2c611f1 100644 --- a/hscontrol/dns.go +++ b/hscontrol/dns.go @@ -7,6 +7,7 @@ import ( "strings" mapset "github.com/deckarep/golang-set/v2" + "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ @@ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { func getMapResponseDNSConfig( dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, - machine Machine, - peers Machines, + machine types.Machine, + peers types.Machines, ) *tailcfg.DNSConfig { var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled @@ -200,7 +201,7 @@ func getMapResponseDNSConfig( ), ) - userSet := mapset.NewSet[User]() + userSet := mapset.NewSet[types.User]() userSet.Add(machine.User) for _, p := range peers { userSet.Add(p.User) diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go index 671a712..6bee0ea 100644 --- a/hscontrol/dns_test.go +++ b/hscontrol/dns_test.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) - machineInShared1 := &Machine{ + machineInShared1 := &types.Machine{ ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", @@ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_1", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.db.Save(machineInShared1) + err = app.db.MachineSave(machineInShared1) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) - machineInShared2 := &Machine{ + machineInShared2 := &types.Machine{ ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_2", UserID: userShared2.ID, User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.db.Save(machineInShared2) + err = app.db.MachineSave(machineInShared2) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) - machineInShared3 := &Machine{ + machineInShared3 := &types.Machine{ ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_3", UserID: userShared3.ID, User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.db.Save(machineInShared3) + err = app.db.MachineSave(machineInShared3) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) - machine2InShared1 := &Machine{ + machine2InShared1 := &types.Machine{ ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_4", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), } - app.db.db.Save(machine2InShared1) + err = app.db.MachineSave(machine2InShared1) + c.Assert(err, check.IsNil) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Proxied: true, } - peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) + peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( @@ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) - machineInShared1 := &Machine{ + machineInShared1 := &types.Machine{ ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", @@ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_1", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.db.Save(machineInShared1) + err = app.db.MachineSave(machineInShared1) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) - machineInShared2 := &Machine{ + machineInShared2 := &types.Machine{ ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_2", UserID: userShared2.ID, User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.db.Save(machineInShared2) + err = app.db.MachineSave(machineInShared2) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) - machineInShared3 := &Machine{ + machineInShared3 := &types.Machine{ ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_3", UserID: userShared3.ID, User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.db.Save(machineInShared3) + err = app.db.MachineSave(machineInShared3) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) - machine2InShared1 := &Machine{ + machine2InShared1 := &types.Machine{ ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_4", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), } - app.db.db.Save(machine2InShared1) + err = app.db.MachineSave(machine2InShared1) + c.Assert(err, check.IsNil) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Proxied: false, } - peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) + peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 4a26d08..8adf871 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,6 +8,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" @@ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser( return nil, err } - return &v1.GetUserResponse{User: user.toProto()}, nil + return &v1.GetUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) CreateUser( @@ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } - return &v1.CreateUserResponse{User: user.toProto()}, nil + return &v1.CreateUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) RenameUser( @@ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser( return nil, err } - return &v1.RenameUserResponse{User: user.toProto()}, nil + return &v1.RenameUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( @@ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers( response := make([]*v1.User, len(users)) for index, user := range users { - response[index] = user.toProto() + response[index] = user.Proto() } log.Trace().Caller().Interface("users", response).Msg("") @@ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( return nil, err } - return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil + return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil } func (api headscaleV1APIServer) ExpirePreAuthKey( @@ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response := make([]*v1.PreAuthKey, len(preAuthKeys)) for index, key := range preAuthKeys { - response[index] = key.toProto() + response[index] = key.Proto() } return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine( request.GetKey(), request.GetUser(), nil, - RegisterMethodCLI, + util.RegisterMethodCLI, ) if err != nil { return nil, err } - return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil + return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) GetMachine( @@ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine( return nil, err } - return &v1.GetMachineResponse{Machine: machine.toProto()}, nil + return &v1.GetMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) SetTags( @@ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags( } } - err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) + err = api.h.db.SetTags(machine, request.GetTags()) if err != nil { return &v1.SetTagsResponse{ Machine: nil, @@ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags( Strs("tags", request.GetTags()). Msg("Changing tags of machine") - return &v1.SetTagsResponse{Machine: machine.toProto()}, nil + return &v1.SetTagsResponse{Machine: machine.Proto()}, nil } func validateTag(tag string) error { @@ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine( Time("expiry", *machine.Expiry). Msg("machine expired") - return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil + return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) RenameMachine( @@ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine( Str("new_name", request.GetNewName()). Msg("machine renamed") - return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil + return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) ListMachines( @@ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines( response := make([]*v1.Machine, len(machines)) for index, machine := range machines { - response[index] = machine.toProto() + response[index] = machine.Proto() } return &v1.ListMachinesResponse{Machines: response}, nil @@ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines( response := make([]*v1.Machine, len(machines)) for index, machine := range machines { - m := machine.toProto() - validTags, invalidTags := getTags( - api.h.aclPolicy, + m := machine.Proto() + validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( machine, api.h.cfg.OIDC.StripEmaildomain, ) @@ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine( return nil, err } - return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil + return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) GetRoutes( @@ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes( } return &v1.GetRoutesResponse{ - Routes: Routes(routes).toProto(), + Routes: types.Routes(routes).Proto(), }, nil } @@ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes( } return &v1.GetMachineRoutesResponse{ - Routes: Routes(routes).toProto(), + Routes: types.Routes(routes).Proto(), }, nil } @@ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey( ctx context.Context, request *v1.ExpireApiKeyRequest, ) (*v1.ExpireApiKeyResponse, error) { - var apiKey *APIKey + var apiKey *types.APIKey var err error apiKey, err = api.h.db.GetAPIKey(request.Prefix) @@ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys( response := make([]*v1.ApiKey, len(apiKeys)) for index, key := range apiKeys { - response[index] = key.toProto() + response[index] = key.Proto() } return &v1.ListApiKeysResponse{ApiKeys: response}, nil @@ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( return nil, err } - newMachine := Machine{ + newMachine := types.Machine{ MachineKey: request.GetKey(), Hostname: request.GetName(), GivenName: givenName, @@ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( LastSeen: &time.Time{}, LastSuccessfulUpdate: &time.Time{}, - HostInfo: HostInfo(hostinfo), + HostInfo: types.HostInfo(hostinfo), } nodeKey := key.NodePublic{} @@ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( registerCacheExpiration, ) - return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil + return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil } func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/machine_test.go b/hscontrol/machine_test.go deleted file mode 100644 index 0e7d7de..0000000 --- a/hscontrol/machine_test.go +++ /dev/null @@ -1,1386 +0,0 @@ -package hscontrol - -import ( - "fmt" - "net/netip" - "reflect" - "regexp" - "strconv" - "testing" - "time" - - "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func (s *Suite) TestGetMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByID(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByNodeKey(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machine := Machine{ - ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - oldNodeKey := key.NewNode() - - machineKey := key.NewMachine() - - machine := Machine{ - ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - app.db.db.Save(&machine) - - err = app.db.DeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(user.Name, "testmachine") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestHardDeleteMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine3", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - app.db.db.Save(&machine) - - err = app.db.HardDeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(user.Name, "testmachine3") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestListPeers(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - machine := Machine{ - ID: uint64(index), - MachineKey: "foo" + strconv.Itoa(index), - NodeKey: "bar" + strconv.Itoa(index), - DiscoKey: "faa" + strconv.Itoa(index), - Hostname: "testmachine" + strconv.Itoa(index), - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - } - - machine0ByID, err := app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) - - peersOfMachine0, err := app.db.ListPeers(machine0ByID) - c.Assert(err, check.IsNil) - - c.Assert(len(peersOfMachine0), check.Equals, 9) - c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") - c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") - c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") -} - -func (s *Suite) TestGetACLFilteredPeers(c *check.C) { - type base struct { - user *User - key *PreAuthKey - } - - stor := make([]base, 0) - - for _, name := range []string{"test", "admin"} { - user, err := app.db.CreateUser(name) - c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - stor = append(stor, base{user, pak}) - } - - _, err := app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - machine := Machine{ - ID: uint64(index), - MachineKey: "foo" + strconv.Itoa(index), - NodeKey: "bar" + strconv.Itoa(index), - DiscoKey: "faa" + strconv.Itoa(index), - IPAddresses: MachineAddresses{ - netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), - }, - Hostname: "testmachine" + strconv.Itoa(index), - UserID: stor[index%2].user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(stor[index%2].key.ID), - } - app.db.db.Save(&machine) - } - - app.aclPolicy = &ACLPolicy{ - Groups: map[string][]string{ - "group:test": {"admin"}, - }, - Hosts: map[string]netip.Prefix{}, - TagOwners: map[string][]string{}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"admin"}, - Destinations: []string{"*:*"}, - }, - { - Action: "accept", - Sources: []string{"test"}, - Destinations: []string{"test:*"}, - }, - }, - Tests: []ACLTest{}, - } - - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - - adminMachine, err := app.db.GetMachineByID(1) - c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) - c.Assert(err, check.IsNil) - - testMachine, err := app.db.GetMachineByID(2) - c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - peersOfTestMachine := app.db.filterMachinesByACL(app.aclRules, testMachine, machines) - peersOfAdminMachine := app.db.filterMachinesByACL(app.aclRules, adminMachine, machines) - - c.Log(peersOfTestMachine) - c.Assert(len(peersOfTestMachine), check.Equals, 9) - c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") - c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") - c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") - - c.Log(peersOfAdminMachine) - c.Assert(len(peersOfAdminMachine), check.Equals, 9) - c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") - c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") - c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") -} - -func (s *Suite) TestExpireMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Expiry: &time.Time{}, - } - app.db.db.Save(machine) - - machineFromDB, err := app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert(machineFromDB, check.NotNil) - - c.Assert(machineFromDB.isExpired(), check.Equals, false) - - err = app.db.ExpireMachine(machineFromDB) - c.Assert(err, check.IsNil) - - c.Assert(machineFromDB.isExpired(), check.Equals, true) -} - -func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { - input := MachineAddresses([]netip.Addr{ - netip.MustParseAddr("192.0.2.1"), - netip.MustParseAddr("2001:db8::1"), - }) - serialized, err := input.Value() - c.Assert(err, check.IsNil) - if serial, ok := serialized.(string); ok { - c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") - } - - var deserialized MachineAddresses - err = deserialized.Scan(serialized) - c.Assert(err, check.IsNil) - - c.Assert(len(deserialized), check.Equals, len(input)) - for i := range deserialized { - c.Assert(deserialized[i], check.Equals, input[i]) - } -} - -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := app.db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user-1", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "machine-key-1", - NodeKey: "node-key-1", - DiscoKey: "disco-key-1", - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - givenName, err := app.db.GenerateGivenName("machine-key-2", "hostname-2") - comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = app.db.GenerateGivenName("machine-key-1", "hostname-1") - comment = check.Commentf("Same user, same machine, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") - comment = check.Commentf("Same user, unique machines, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) - - givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") - comment = check.Commentf("Unique users, unique machines, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) -} - -func (s *Suite) TestSetTags(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - // assign simple tags - sTags := []string{"tag:test", "tag:foo"} - err = app.db.SetTags(machine, sTags, app.UpdateACLRules) - c.Assert(err, check.IsNil) - machine, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) - - // assign duplicat tags, expect no errors but no doubles in DB - eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = app.db.SetTags(machine, eTags, app.UpdateACLRules) - c.Assert(err, check.IsNil) - machine, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert( - machine.ForcedTags, - check.DeepEquals, - StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), - ) -} - -func Test_getTags(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - machine Machine - stripEmailDomain bool - } - tests := []struct { - name string - args args - wantInvalid []string - wantValid []string - }{ - { - name: "valid tag one machine", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:valid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: nil, - }, - { - name: "invalid tag and valid tag one machine", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:valid", "tag:invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "multiple invalid and identical tags, should return only one invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{ - "tag:invalid", - "tag:valid", - "tag:invalid", - }, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "only invalid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - { - name: "empty ACLPolicy should return empty tags and should not panic", - args: args{ - aclPolicy: nil, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: nil, - wantInvalid: nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotValid, gotInvalid := getTags( - test.args.aclPolicy, - test.args.machine, - test.args.stripEmailDomain, - ) - for _, valid := range gotValid { - if !util.StringOrPrefixListContains(test.wantValid, valid) { - t.Errorf( - "valids: getTags() = %v, want %v", - gotValid, - test.wantValid, - ) - - break - } - } - for _, invalid := range gotInvalid { - if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { - t.Errorf( - "invalids: getTags() = %v, want %v", - gotInvalid, - test.wantInvalid, - ) - - break - } - } - }) - } -} - -func Test_getFilteredByACLPeers(t *testing.T) { - type args struct { - machines []Machine - rules []tailcfg.FilterRule - machine *Machine - } - tests := []struct { - name string - args args - want Machines - }{ - { - name: "all hosts can talk to each other", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "One host can talk to another, but not all hosts", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - }, - { - name: "host cannot directly talk to destination, but return path is authorized", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination, destination can reach all hosts", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "rule allows all hosts to reach all destinations", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "without rule all communications are forbidden", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{}, - }, - { - // Investigating 699 - // Found some machines: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa - // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] - // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} - name: "issue-699-broken-star", - args: args{ - machines: Machines{ // - { - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: User{Name: "user1"}, - }, - { - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: User{Name: "user1"}, - }, - { - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: User{Name: "user2"}, - }, - { - ID: 4, - Hostname: "ts-unstable-lys2ib", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("fd7a:115c:a1e0::2"), - }, - User: User{Name: "user2"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - DstPorts: []tailcfg.NetPortRange{ - { - IP: "*", - Ports: tailcfg.PortRange{First: 0, Last: 65535}, - }, - }, - SrcIPs: []string{ - "fd7a:115c:a1e0::3", "100.64.0.3", - "fd7a:115c:a1e0::4", "100.64.0.4", - }, - }, - }, - machine: &Machine{ // current machine - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: User{Name: "user2"}, - }, - }, - want: Machines{ - { - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: User{Name: "user1"}, - }, - { - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: User{Name: "user1"}, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := filterMachinesByACL( - tt.args.machine, - tt.args.machines, - tt.args.rules, - ) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestHeadscale_generateGivenName(t *testing.T) { - type args struct { - suppliedName string - randomSuffix bool - } - tests := []struct { - name string - db *HSDatabase - args args - want *regexp.Regexp - wantErr bool - }{ - { - name: "simple machine name generation", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "testmachine", - randomSuffix: false, - }, - want: regexp.MustCompile("^testmachine$"), - wantErr: false, - }, - { - name: "machine name with 53 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", - randomSuffix: false, - }, - want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), - wantErr: false, - }, - { - name: "machine name with 63 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), - wantErr: false, - }, - { - name: "machine name with 64 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "machine name with 73 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "machine name with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "test", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), - wantErr: false, - }, - { - name: "machine name with 63 chars with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) - if (err != nil) != tt.wantErr { - t.Errorf( - "Headscale.GenerateGivenName() error = %v, wantErr %v", - err, - tt.wantErr, - ) - - return - } - - if tt.want != nil && !tt.want.MatchString(got) { - t.Errorf( - "Headscale.GenerateGivenName() = %v, does not match %v", - tt.want, - got, - ) - } - - if len(got) > labelHostnameLength { - t.Errorf( - "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", - got, - labelHostnameLength, - ) - } - }) - } -} - -func (s *Suite) TestAutoApproveRoutes(c *check.C) { - acl := []byte(` -{ - "tagOwners": { - "tag:exit": ["test"], - }, - - "groups": { - "group:test": ["test"] - }, - - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], - - "autoApprovers": { - "exitNode": ["tag:exit"], - "routes": { - "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], - } - } -} - `) - - err := app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - nodeKey := key.NewNode() - - defaultRoute := netip.MustParsePrefix("0.0.0.0/0") - route1 := netip.MustParsePrefix("10.10.0.0/16") - // Check if a subprefix of an autoapproved route is approved - route2 := netip.MustParsePrefix("10.11.0.0/24") - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "test", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo{ - RequestTags: []string{"tag:exit"}, - RoutableIPs: []netip.Prefix{defaultRoute, route1, route2}, - }, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - } - - app.db.db.Save(&machine) - - err = app.db.processMachineRoutes(&machine) - c.Assert(err, check.IsNil) - - machine0ByID, err := app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) - - err = app.db.EnableAutoApprovedRoutes(app.aclPolicy, machine0ByID) - c.Assert(err, check.IsNil) - - enabledRoutes, err := app.db.GetEnabledRoutes(machine0ByID) - c.Assert(err, check.IsNil) - c.Assert(enabledRoutes, check.HasLen, 3) -} - -func TestMachine_canAccess(t *testing.T) { - type args struct { - filter []tailcfg.FilterRule - machine2 *Machine - } - tests := []struct { - name string - machine Machine - args args - want bool - }{ - { - name: "no-rules", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{}, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: false, - }, - { - name: "wildcard", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "*", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: true, - }, - { - name: "explicit-m1-to-m2", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"10.0.0.1"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "10.0.0.2", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: true, - }, - { - name: "explicit-m2-to-m1", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"10.0.0.2"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "10.0.0.1", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.machine.canAccess(tt.args.filter, tt.args.machine2); got != tt.want { - t.Errorf("Machine.canAccess() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/hscontrol/matcher.go b/hscontrol/matcher.go deleted file mode 100644 index 3b4670e..0000000 --- a/hscontrol/matcher.go +++ /dev/null @@ -1,142 +0,0 @@ -package hscontrol - -import ( - "fmt" - "net/netip" - "strings" - - "go4.org/netipx" - "tailscale.com/tailcfg" -) - -// This is borrowed from, and updated to use IPSet -// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 -// TODO(kradalby): contribute upstream and make public. -var ( - zeroIP4 = netip.AddrFrom4([4]byte{}) - zeroIP6 = netip.AddrFrom16([16]byte{}) -) - -// parseIPSet parses arg as one: -// -// - an IP address (IPv4 or IPv6) -// - the string "*" to match everything (both IPv4 & IPv6) -// - a CIDR (e.g. "192.168.0.0/16") -// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") -// -// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP -// address (without a slash) treated as a CIDR of *bits length. -// nolint -func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { - var ipSet netipx.IPSetBuilder - if arg == "*" { - ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) - ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) - - return ipSet.IPSet() - } - if strings.Contains(arg, "/") { - pfx, err := netip.ParsePrefix(arg) - if err != nil { - return nil, err - } - if pfx != pfx.Masked() { - return nil, fmt.Errorf("%v contains non-network bits set", pfx) - } - - ipSet.AddPrefix(pfx) - - return ipSet.IPSet() - } - if strings.Count(arg, "-") == 1 { - ip1s, ip2s, _ := strings.Cut(arg, "-") - - ip1, err := netip.ParseAddr(ip1s) - if err != nil { - return nil, err - } - - ip2, err := netip.ParseAddr(ip2s) - if err != nil { - return nil, err - } - - r := netipx.IPRangeFrom(ip1, ip2) - if !r.IsValid() { - return nil, fmt.Errorf("invalid IP range %q", arg) - } - - for _, prefix := range r.Prefixes() { - ipSet.AddPrefix(prefix) - } - - return ipSet.IPSet() - } - ip, err := netip.ParseAddr(arg) - if err != nil { - return nil, fmt.Errorf("invalid IP address %q", arg) - } - bits8 := uint8(ip.BitLen()) - if bits != nil { - if *bits < 0 || *bits > int(bits8) { - return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) - } - bits8 = uint8(*bits) - } - - ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) - - return ipSet.IPSet() -} - -type Match struct { - Srcs *netipx.IPSet - Dests *netipx.IPSet -} - -func MatchFromFilterRule(rule tailcfg.FilterRule) Match { - srcs := new(netipx.IPSetBuilder) - dests := new(netipx.IPSetBuilder) - - for _, srcIP := range rule.SrcIPs { - set, _ := parseIPSet(srcIP, nil) - - srcs.AddSet(set) - } - - for _, dest := range rule.DstPorts { - set, _ := parseIPSet(dest.IP, nil) - - dests.AddSet(set) - } - - srcsSet, _ := srcs.IPSet() - destsSet, _ := dests.IPSet() - - match := Match{ - Srcs: srcsSet, - Dests: destsSet, - } - - return match -} - -func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Srcs.Contains(ip) { - return true - } - } - - return false -} - -func (m *Match) DestsContainsIP(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Dests.Contains(ip) { - return true - } - } - - return false -} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index c666594..4e68a22 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -14,6 +14,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" @@ -638,7 +640,7 @@ func getUserName( claims *IDTokenClaims, stripEmaildomain bool, ) (string, error) { - userName, err := NormalizeToFQDNRules( + userName, err := util.NormalizeToFQDNRules( claims.Email, stripEmaildomain, ) @@ -663,9 +665,9 @@ func getUserName( func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer http.ResponseWriter, userName string, -) (*User, error) { +) (*types.User, error) { user, err := h.db.GetUser(userName) - if errors.Is(err, ErrUserNotFound) { + if errors.Is(err, db.ErrUserNotFound) { user, err = h.db.CreateUser(userName) if err != nil { @@ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( func (h *Headscale) registerMachineForOIDCCallback( writer http.ResponseWriter, - user *User, + user *types.User, nodeKey *key.NodePublic, expiry time.Time, ) error { @@ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback( nodeKey.String(), user.Name, &expiry, - RegisterMethodOIDC, + util.RegisterMethodOIDC, ); err != nil { log.Error(). Caller(). diff --git a/hscontrol/acls.go b/hscontrol/policy/acls.go similarity index 79% rename from hscontrol/acls.go rename to hscontrol/policy/acls.go index 2c81046..6b42ebe 100644 --- a/hscontrol/acls.go +++ b/hscontrol/policy/acls.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "encoding/json" @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/tailscale/hujson" @@ -22,12 +23,12 @@ import ( ) var ( - errEmptyPolicy = errors.New("empty policy") - errInvalidAction = errors.New("invalid action") - errInvalidGroup = errors.New("invalid group") - errInvalidTag = errors.New("invalid tag") - errInvalidPortFormat = errors.New("invalid port format") - errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") + ErrEmptyPolicy = errors.New("empty policy") + ErrInvalidAction = errors.New("invalid action") + ErrInvalidGroup = errors.New("invalid group") + ErrInvalidTag = errors.New("invalid tag") + ErrInvalidPortFormat = errors.New("invalid port format") + ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") ) const ( @@ -56,7 +57,7 @@ const ( var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. -func (h *Headscale) LoadACLPolicyFromPath(path string) error { +func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { log.Debug(). Str("func", "LoadACLPolicy"). Str("path", path). @@ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { policyFile, err := os.Open(path) if err != nil { - return err + return nil, err } defer policyFile.Close() policyBytes, err := io.ReadAll(policyFile) if err != nil { - return err + return nil, err } log.Debug(). @@ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { switch filepath.Ext(path) { case ".yml", ".yaml": - return h.LoadACLPolicyFromBytes(policyBytes, "yaml") + return LoadACLPolicyFromBytes(policyBytes, "yaml") } - return h.LoadACLPolicyFromBytes(policyBytes, "hujson") + return LoadACLPolicyFromBytes(policyBytes, "hujson") } -func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { +func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { var policy ACLPolicy switch format { case "yaml": err := yaml.Unmarshal(acl, &policy) if err != nil { - return err + return nil, err } default: ast, err := hujson.Parse(acl) if err != nil { - return err + return nil, err } ast.Standardize() acl = ast.Pack() err = json.Unmarshal(acl, &policy) if err != nil { - return err + return nil, err } } if policy.IsZero() { - return errEmptyPolicy + return nil, ErrEmptyPolicy } - h.aclPolicy = &policy - - return h.UpdateACLRules() + return &policy, nil } -func (h *Headscale) UpdateACLRules() error { - machines, err := h.db.ListMachines() - if err != nil { - return err +// TODO(kradalby): This needs to be replace with something that generates +// the rules as needed and not stores it on the global object, rules are +// per node and that should be taken into account. +func GenerateFilterRules( + policy *ACLPolicy, + machines types.Machines, + stripEmailDomain bool, +) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { + if policy == nil { + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy } - if h.aclPolicy == nil { - return errEmptyPolicy - } - - rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain) + rules, err := policy.generateFilterRules(machines, stripEmailDomain) if err != nil { - return err + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Msg("ACL rules generated") - h.aclRules = rules + var sshPolicy *tailcfg.SSHPolicy if featureEnableSSH() { - sshRules, err := h.generateSSHRules() + sshRules, err := generateSSHRules(policy, machines, stripEmailDomain) if err != nil { - return err + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") - if h.sshPolicy == nil { - h.sshPolicy = &tailcfg.SSHPolicy{} + if sshPolicy == nil { + sshPolicy = &tailcfg.SSHPolicy{} } - h.sshPolicy.Rules = sshRules - } else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 { + sshPolicy.Rules = sshRules + } else if policy != nil && len(policy.SSHs) > 0 { log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") } - return nil + return rules, sshPolicy, nil } // generateFilterRules takes a set of machines and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) generateFilterRules( - machines []Machine, + machines types.Machines, stripEmailDomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} for index, acl := range pol.ACLs { if acl.Action != "accept" { - return nil, errInvalidAction + return nil, ErrInvalidAction } srcIPs := []string{} @@ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules( return rules, nil } -func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { +func generateSSHRules( + policy *ACLPolicy, + machines types.Machines, + stripEmailDomain bool, +) ([]*tailcfg.SSHRule, error) { rules := []*tailcfg.SSHRule{} - if h.aclPolicy == nil { - return nil, errEmptyPolicy - } - - machines, err := h.db.ListMachines() - if err != nil { - return nil, err + if policy == nil { + return nil, ErrEmptyPolicy } acceptAction := tailcfg.SSHAction{ @@ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { AllowLocalPortForwarding: false, } - for index, sshACL := range h.aclPolicy.SSHs { + for index, sshACL := range policy.SSHs { action := rejectAction switch sshACL.Action { case "accept": @@ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { } default: log.Error(). - Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) + Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action) - return nil, err + continue } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { Any: true, }) } else if isGroup(rawSrc) { - users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain) + users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain) if err != nil { log.Error(). Msgf("Error parsing SSH %d, Source %d", index, innerIndex) @@ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { }) } } else { - expandedSrcs, err := h.aclPolicy.expandAlias( + expandedSrcs, err := policy.ExpandAlias( machines, rawSrc, - h.cfg.OIDC.StripEmaildomain, + stripEmailDomain, ) if err != nil { log.Error(). @@ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { // with the given src alias. func (pol *ACLPolicy) getIPsFromSource( src string, - machines []Machine, + machines types.Machines, stripEmaildomain bool, ) ([]string, error) { - ipSet, err := pol.expandAlias(machines, src, stripEmaildomain) + ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) if err != nil { return []string{}, err } @@ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource( // which are associated with the dest alias. func (pol *ACLPolicy) getNetPortRangeFromDestination( dest string, - machines []Machine, + machines types.Machines, needsWildcard bool, stripEmaildomain bool, ) ([]tailcfg.NetPortRange, error) { @@ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( return nil, fmt.Errorf( "failed to parse destination, tokens %v: %w", tokens, - errInvalidPortFormat, + ErrInvalidPortFormat, ) } else { tokens = []string{maybeIPv6Str, port} @@ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) } - expanded, err := pol.expandAlias( + expanded, err := pol.ExpandAlias( machines, alias, stripEmaildomain, @@ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) { // - an ip // - a cidr // and transform these in IPAddresses. -func (pol *ACLPolicy) expandAlias( - machines Machines, +func (pol *ACLPolicy) ExpandAlias( + machines types.Machines, alias string, stripEmailDomain bool, ) (*netipx.IPSet, error) { if isWildcard(alias) { - return parseIPSet("*", nil) + return util.ParseIPSet("*", nil) } build := netipx.IPSetBuilder{} @@ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias( // if alias is an host // Note, this is recursive. if h, ok := pol.Hosts[alias]; ok { - log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") + log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.expandAlias(machines, h.String(), stripEmailDomain) + return pol.ExpandAlias(machines, h.String(), stripEmailDomain) } // if alias is an IP @@ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias( // we assume in this function that we only have nodes from 1 user. func excludeCorrectlyTaggedNodes( aclPolicy *ACLPolicy, - nodes []Machine, + nodes types.Machines, user string, stripEmailDomain bool, -) []Machine { - out := []Machine{} +) types.Machines { + out := types.Machines{} tags := []string{} for tag := range aclPolicy.TagOwners { owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) @@ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err } if needsWildcard { - return nil, errWildcardIsNeeded + return nil, ErrWildcardIsNeeded } ports := []tailcfg.PortRange{} @@ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err }) default: - return nil, errInvalidPortFormat + return nil, ErrInvalidPortFormat } } return &ports, nil } -func filterMachinesByUser(machines []Machine, user string) []Machine { - out := []Machine{} +func filterMachinesByUser(machines types.Machines, user string) types.Machines { + out := types.Machines{} for _, machine := range machines { if machine.User.Name == user { out = append(out, machine) @@ -664,7 +664,7 @@ func getTagOwners( if !ok { return []string{}, fmt.Errorf( "%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", - errInvalidTag, + ErrInvalidTag, tag, ) } @@ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup( return []string{}, fmt.Errorf( "group %v isn't registered. %w", group, - errInvalidGroup, + ErrInvalidGroup, ) } for _, group := range aclGroups { if isGroup(group) { return []string{}, fmt.Errorf( "%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", - errInvalidGroup, + ErrInvalidGroup, ) } - grp, err := NormalizeToFQDNRules(group, stripEmailDomain) + grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) if err != nil { return []string{}, fmt.Errorf( "failed to normalize group %q, err: %w", group, - errInvalidGroup, + ErrInvalidGroup, ) } users = append(users, grp) @@ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup( func (pol *ACLPolicy) getIPsFromGroup( group string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) getIPsFromTag( alias string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag( // find tag owners owners, err := getTagOwners(pol, alias, stripEmailDomain) if err != nil { - if errors.Is(err, errInvalidTag) { + if errors.Is(err, ErrInvalidTag) { ipSet, _ := build.IPSet() if len(ipSet.Prefixes()) == 0 { return ipSet, fmt.Errorf( "%w. %v isn't owned by a TagOwner and no forced tags are defined", - errInvalidTag, + ErrInvalidTag, alias, ) } @@ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsForUser( user string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) getIPsFromSingleIP( ip netip.Addr, - machines Machines, + machines types.Machines, ) (*netipx.IPSet, error) { - log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") + log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") matches := machines.FilterByIP(ip) @@ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP( func (pol *ACLPolicy) getIPsFromIPPrefix( prefix netip.Prefix, - machines Machines, + machines types.Machines, ) (*netipx.IPSet, error) { log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") build := netipx.IPSetBuilder{} @@ -862,3 +862,65 @@ func isGroup(str string) bool { func isTag(str string) bool { return strings.HasPrefix(str, "tag:") } + +// getTags will return the tags of the current machine. +// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. +// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. +func (pol *ACLPolicy) GetTagsOfMachine( + machine types.Machine, + stripEmailDomain bool, +) ([]string, []string) { + validTags := make([]string, 0) + invalidTags := make([]string, 0) + + validTagMap := make(map[string]bool) + invalidTagMap := make(map[string]bool) + for _, tag := range machine.HostInfo.RequestTags { + owners, err := getTagOwners(pol, tag, stripEmailDomain) + if errors.Is(err, ErrInvalidTag) { + invalidTagMap[tag] = true + + continue + } + var found bool + for _, owner := range owners { + if machine.User.Name == owner { + found = true + } + } + if found { + validTagMap[tag] = true + } else { + invalidTagMap[tag] = true + } + } + for tag := range invalidTagMap { + invalidTags = append(invalidTags, tag) + } + for tag := range validTagMap { + validTags = append(validTags, tag) + } + + return validTags, invalidTags +} + +// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. +func FilterMachinesByACL( + machine *types.Machine, + machines types.Machines, + filter []tailcfg.FilterRule, +) types.Machines { + result := types.Machines{} + + for index, peer := range machines { + if peer.ID == machine.ID { + continue + } + + if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) { + result = append(result, peer) + } + } + + return result +} diff --git a/hscontrol/acls_test.go b/hscontrol/policy/acls_test.go similarity index 56% rename from hscontrol/acls_test.go rename to hscontrol/policy/acls_test.go index 70a57b8..f6c5e10 100644 --- a/hscontrol/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "errors" @@ -7,15 +7,24 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "go4.org/netipx" "gopkg.in/check.v1" - "tailscale.com/envknob" "tailscale.com/tailcfg" ) +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + func (s *Suite) TestWrongPath(c *check.C) { - err := app.LoadACLPolicyFromPath("asdfg") + _, err := LoadACLPolicyFromPath("asdfg") c.Assert(err, check.NotNil) } @@ -23,7 +32,7 @@ func (s *Suite) TestBrokenHuJson(c *check.C) { acl := []byte(` { `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + _, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.NotNil) } @@ -34,9 +43,9 @@ func (s *Suite) TestInvalidPolicyHuson(c *check.C) { "but_a_policy_though": false } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + _, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.NotNil) - c.Assert(err, check.Equals, errEmptyPolicy) + c.Assert(err, check.Equals, ErrEmptyPolicy) } func (s *Suite) TestParseHosts(c *check.C) { @@ -185,8 +194,13 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(pol.ACLs, check.HasLen, 6) + c.Assert(err, check.IsNil) + + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.NotNil) + c.Assert(rules, check.IsNil) } func (s *Suite) TestBasicRule(c *check.C) { @@ -212,17 +226,17 @@ func (s *Suite) TestBasicRule(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } // TODO(kradalby): Make tests values safe, independent and descriptive. func (s *Suite) TestInvalidAction(c *check.C) { - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ ACLs: []ACL{ { Action: "invalidAction", @@ -231,88 +245,13 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidAction), check.Equals, true) -} - -func (s *Suite) TestSshRules(c *check.C) { - envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") - - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1"}, - }, - Hosts: Hosts{ - "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - SSHs: []SSH{ - { - Action: "accept", - Sources: []string{"group:test"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - }, - } - - err = app.UpdateACLRules() - - c.Assert(err, check.IsNil) - c.Assert(app.sshPolicy, check.NotNil) - c.Assert(app.sshPolicy.Rules, check.HasLen, 2) - c.Assert(app.sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[0].Principals, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") - - c.Assert(app.sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[1].Principals, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } func (s *Suite) TestInvalidGroupInGroup(c *check.C) { // this ACL is wrong because the group in Sources sections doesn't exist - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ Groups: Groups{ "group:test": []string{"foo"}, "group:error": []string{"foo", "group:test"}, @@ -325,13 +264,13 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidGroup), check.Equals, true) + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } func (s *Suite) TestInvalidTagOwners(c *check.C) { // this ACL is wrong because no tagOwners own the requested tag for the server - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ ACLs: []ACL{ { Action: "accept", @@ -340,232 +279,9 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidTag), check.Equals, true) -} -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Sources section. -func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:test"}, - Destinations: []string{"*:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Destinations section. -func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"tag:test:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].DstPorts, check.HasLen, 1) - c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") -} - -// need a test with: -// tag on a host that isn't owned by a tag owners. So the user -// of the host should be valid. -func (s *Suite) TestInvalidTagValidUser(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:foo"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"*:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// tag on a host is owned by a tag owner, the tag is valid. -// an ACL rule is matching the tag to a user. It should not be valid since the -// host should be tied to the tag now. -func (s *Suite) TestValidTagInvalidUser(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "webserver") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "webserver", - RequestTags: []string{"tag:webapp"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "webserver", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - _, err = app.db.GetMachine("user1", "user") - hostInfo2 := tailcfg.Hostinfo{ - OS: "debian", - Hostname: "Hostname", - } - c.Assert(err, check.NotNil) - machine = Machine{ - ID: 2, - MachineKey: "56789", - NodeKey: "bar2", - DiscoKey: "faab", - Hostname: "user", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"tag:webapp:80,443"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") - c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2) - c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) - c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) - c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) - c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) - c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } func (s *Suite) TestPortRange(c *check.C) { @@ -589,10 +305,11 @@ func (s *Suite) TestPortRange(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -644,10 +361,11 @@ func (s *Suite) TestProtocolParsing(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -678,10 +396,11 @@ func (s *Suite) TestPortWildcard(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -704,10 +423,11 @@ acls: - "*" dst: - host-1:*`) - err := app.LoadACLPolicyFromBytes(acl, "yaml") + pol, err := LoadACLPolicyFromBytes(acl, "yaml") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -719,138 +439,6 @@ acls: c.Assert(rules[0].SrcIPs[0], check.Equals, "0.0.0.0/0") } -func (s *Suite) TestPortUser(c *check.C) { - user, err := app.db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := app.db.getAvailableIPs() - machine := Machine{ - ID: 0, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - acl := []byte(` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} - `) - err = app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - rules, err := app.aclPolicy.generateFilterRules(machines, false) - c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") -} - -func (s *Suite) TestPortGroup(c *check.C) { - user, err := app.db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := app.db.getAvailableIPs() - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - acl := []byte(` -{ - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} - `) - err = app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - rules, err := app.aclPolicy.generateFilterRules(machines, false) - c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") -} - func Test_expandGroup(t *testing.T) { type field struct { pol ACLPolicy @@ -1151,54 +739,54 @@ func Test_expandPorts(t *testing.T) { func Test_listMachinesInUser(t *testing.T) { type args struct { - machines []Machine + machines types.Machines user string } tests := []struct { name string args args - want []Machine + want types.Machines }{ { name: "1 machine in user", args: args{ - machines: []Machine{ - {User: User{Name: "joe"}}, + machines: types.Machines{ + {User: types.User{Name: "joe"}}, }, user: "joe", }, - want: []Machine{ - {User: User{Name: "joe"}}, + want: types.Machines{ + {User: types.User{Name: "joe"}}, }, }, { name: "3 machines, 2 in user", args: args{ - machines: []Machine{ - {ID: 1, User: User{Name: "joe"}}, - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, + machines: types.Machines{ + {ID: 1, User: types.User{Name: "joe"}}, + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, }, user: "marc", }, - want: []Machine{ - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, + want: types.Machines{ + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, }, }, { name: "5 machines, 0 in user", args: args{ - machines: []Machine{ - {ID: 1, User: User{Name: "joe"}}, - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, - {ID: 4, User: User{Name: "marc"}}, - {ID: 5, User: User{Name: "marc"}}, + machines: types.Machines{ + {ID: 1, User: types.User{Name: "joe"}}, + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, + {ID: 4, User: types.User{Name: "marc"}}, + {ID: 5, User: types.User{Name: "marc"}}, }, user: "mickael", }, - want: []Machine{}, + want: types.Machines{}, }, } for _, test := range tests { @@ -1234,7 +822,7 @@ func Test_expandAlias(t *testing.T) { pol ACLPolicy } type args struct { - machines []Machine + machines types.Machines aclPolicy ACLPolicy alias string stripEmailDomain bool @@ -1253,10 +841,10 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "*", - machines: []Machine{ - {IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, + machines: types.Machines{ + {IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.78.84.227"), }, }, @@ -1278,30 +866,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "group:accountant", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1320,30 +908,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "group:hr", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1358,7 +946,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.3", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{ @@ -1373,7 +961,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{ @@ -1388,12 +976,12 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1410,13 +998,13 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1433,13 +1021,13 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1460,7 +1048,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "testy", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{}, []string{"10.0.0.132/32"}), @@ -1477,7 +1065,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "homeNetwork", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{}, []string{"192.168.1.0/24"}), @@ -1490,7 +1078,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.0/16", - machines: []Machine{}, + machines: types.Machines{}, aclPolicy: ACLPolicy{}, stripEmailDomain: true, }, @@ -1506,40 +1094,40 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, stripEmailDomain: true, @@ -1561,30 +1149,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1599,32 +1187,32 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1643,36 +1231,36 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1689,40 +1277,40 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "joe", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, stripEmailDomain: true, @@ -1733,7 +1321,7 @@ func Test_expandAlias(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.field.pol.expandAlias( + got, err := test.field.pol.ExpandAlias( test.args.machines, test.args.alias, test.args.stripEmailDomain, @@ -1753,14 +1341,14 @@ func Test_expandAlias(t *testing.T) { func Test_excludeCorrectlyTaggedNodes(t *testing.T) { type args struct { aclPolicy *ACLPolicy - nodes []Machine + nodes types.Machines user string stripEmailDomain bool } tests := []struct { name string args args - want []Machine + want types.Machines wantErr bool }{ { @@ -1769,43 +1357,43 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1820,43 +1408,43 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { "tag:accountant-webserver": []string{"group:accountant"}, }, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1866,39 +1454,39 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:accountant-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1908,67 +1496,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, }, @@ -1993,7 +1581,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - machines []Machine + machines types.Machines stripEmailDomain bool } tests := []struct { @@ -2024,7 +1612,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: []tailcfg.FilterRule{ @@ -2064,27 +1652,30 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, }, want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.2/32", @@ -2113,14 +1704,631 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { tt.args.stripEmailDomain, ) if (err != nil) != tt.wantErr { - t.Errorf("ACLPolicy.generateFilterRules() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) return } if diff := cmp.Diff(tt.want, got); diff != "" { log.Trace().Interface("got", got).Msg("result") - t.Errorf("ACLPolicy.generateFilterRules() = %v, want %v", got, tt.want) + t.Errorf("ACLgenerateFilterRules() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getTags(t *testing.T) { + type args struct { + aclPolicy *ACLPolicy + machine types.Machine + stripEmailDomain bool + } + tests := []struct { + name string + args args + wantInvalid []string + wantValid []string + }{ + { + name: "valid tag one machine", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:valid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: nil, + }, + { + name: "invalid tag and valid tag one machine", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:valid", "tag:invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: []string{"tag:invalid"}, + }, + { + name: "multiple invalid and identical tags, should return only one invalid tag", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{ + "tag:invalid", + "tag:valid", + "tag:invalid", + }, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: []string{"tag:invalid"}, + }, + { + name: "only invalid tags", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:invalid", "very-invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: nil, + wantInvalid: []string{"tag:invalid", "very-invalid"}, + }, + { + name: "empty ACLPolicy should return empty tags and should not panic", + args: args{ + aclPolicy: &ACLPolicy{}, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:invalid", "very-invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: nil, + wantInvalid: []string{"tag:invalid", "very-invalid"}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( + test.args.machine, + test.args.stripEmailDomain, + ) + for _, valid := range gotValid { + if !util.StringOrPrefixListContains(test.wantValid, valid) { + t.Errorf( + "valids: getTags() = %v, want %v", + gotValid, + test.wantValid, + ) + + break + } + } + for _, invalid := range gotInvalid { + if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { + t.Errorf( + "invalids: getTags() = %v, want %v", + gotInvalid, + test.wantInvalid, + ) + + break + } + } + }) + } +} + +func Test_getFilteredByACLPeers(t *testing.T) { + type args struct { + machines types.Machines + rules []tailcfg.FilterRule + machine *types.Machine + } + tests := []struct { + name string + args args + want types.Machines + }{ + { + name: "all hosts can talk to each other", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "One host can talk to another, but not all hosts", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + }, + { + name: "host cannot directly talk to destination, but return path is authorized", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination, destination can reach all hosts", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rule allows all hosts to reach all destinations", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "without rule all communications are forbidden", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{}, + }, + { + // Investigating 699 + // Found some machines: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa + // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] + // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} + name: "issue-699-broken-star", + args: args{ + machines: types.Machines{ // + { + ID: 1, + Hostname: "ts-head-upcrmb", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + netip.MustParseAddr("fd7a:115c:a1e0::3"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.4"), + netip.MustParseAddr("fd7a:115c:a1e0::4"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 3, + Hostname: "ts-head-8w6paa", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user2"}, + }, + { + ID: 4, + Hostname: "ts-unstable-lys2ib", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + netip.MustParseAddr("fd7a:115c:a1e0::2"), + }, + User: types.User{Name: "user2"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{First: 0, Last: 65535}, + }, + }, + SrcIPs: []string{ + "fd7a:115c:a1e0::3", "100.64.0.3", + "fd7a:115c:a1e0::4", "100.64.0.4", + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 3, + Hostname: "ts-head-8w6paa", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user2"}, + }, + }, + want: types.Machines{ + { + ID: 1, + Hostname: "ts-head-upcrmb", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + netip.MustParseAddr("fd7a:115c:a1e0::3"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.4"), + netip.MustParseAddr("fd7a:115c:a1e0::4"), + }, + User: types.User{Name: "user1"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FilterMachinesByACL( + tt.args.machine, + tt.args.machines, + tt.args.rules, + ) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want) } }) } diff --git a/hscontrol/acls_types.go b/hscontrol/policy/acls_types.go similarity index 99% rename from hscontrol/acls_types.go rename to hscontrol/policy/acls_types.go index 0e55351..e9c4490 100644 --- a/hscontrol/acls_types.go +++ b/hscontrol/policy/acls_types.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "encoding/json" diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go new file mode 100644 index 0000000..8458339 --- /dev/null +++ b/hscontrol/policy/matcher/matcher.go @@ -0,0 +1,61 @@ +package matcher + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/util" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +type Match struct { + Srcs *netipx.IPSet + Dests *netipx.IPSet +} + +func MatchFromFilterRule(rule tailcfg.FilterRule) Match { + srcs := new(netipx.IPSetBuilder) + dests := new(netipx.IPSetBuilder) + + for _, srcIP := range rule.SrcIPs { + set, _ := util.ParseIPSet(srcIP, nil) + + srcs.AddSet(set) + } + + for _, dest := range rule.DstPorts { + set, _ := util.ParseIPSet(dest.IP, nil) + + dests.AddSet(set) + } + + srcsSet, _ := srcs.IPSet() + destsSet, _ := dests.IPSet() + + match := Match{ + Srcs: srcsSet, + Dests: destsSet, + } + + return match +} + +func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Srcs.Contains(ip) { + return true + } + } + + return false +} + +func (m *Match) DestsContainsIP(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Dests.Contains(ip) { + return true + } + } + + return false +} diff --git a/hscontrol/policy/matcher/matcher_test.go b/hscontrol/policy/matcher/matcher_test.go new file mode 100644 index 0000000..54cf8a0 --- /dev/null +++ b/hscontrol/policy/matcher/matcher_test.go @@ -0,0 +1 @@ +package matcher diff --git a/hscontrol/protocol_common.go b/hscontrol/protocol_common.go index 5cd0ddb..ae034fb 100644 --- a/hscontrol/protocol_common.go +++ b/hscontrol/protocol_common.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" @@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon( // that we rely on a method that calls back some how (OpenID or CLI) // We create the machine and then keep it around until a callback // happens - newMachine := Machine{ + newMachine := types.Machine{ MachineKey: util.MachinePublicKeyStripPrefix(machineKey), Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, @@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon( []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil || storedMachineKey.IsZero() { - machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) - if err := h.db.db.Save(&machine).Error; err != nil { + if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil { log.Error(). Caller(). Str("func", "RegistrationHandler"). @@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon( // If machine is not expired, and it is register, we have a already accepted this machine, // let it proceed with a valid registration - if !machine.isExpired() { + if !machine.IsExpired() { h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) return @@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon( // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && - !machine.isExpired() { + !machine.IsExpired() { h.handleMachineRefreshKeyCommon( writer, registerRequest, @@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) + pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). @@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon( Err(err). Msg("Cannot encode message") http.Error(writer, "Internal server error", http.StatusInternalServerError) - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() return @@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon( Msg("Failed authentication via AuthKey") if pak != nil { - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() } else { - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() } return @@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon( return } - aclTags := pak.toProto().AclTags + aclTags := pak.Proto().AclTags if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) + err = h.db.SetTags(machine, aclTags) if err != nil { log.Error(). @@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon( return } - machineToRegister := Machine{ + machineToRegister := types.Machine{ Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, MachineKey: util.MachinePublicKeyStripPrefix(machineKey), - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, LastSeen: &now, AuthKeyID: uint(pak.ID), - ForcedTags: pak.toProto().AclTags, + ForcedTags: pak.Proto().AclTags, } machine, err = h.db.RegisterMachine( @@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon( Bool("noise", isNoise). Err(err). Msg("could not register machine") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) @@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon( Bool("noise", isNoise). Err(err). Msg("Failed to use pre-auth key") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) @@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon( } resp.MachineAuthorized = true - resp.User = *pak.User.toTailscaleUser() + resp.User = *pak.User.TailscaleUser() // Provide LoginName when registering with pre-auth key // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* - resp.Login = *pak.User.toTailscaleLogin() + resp.Login = *pak.User.TailscaleLogin() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { @@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon( Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name). Inc() writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon( func (h *Headscale) handleMachineLogOutCommon( writer http.ResponseWriter, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon( resp.AuthURL = "" resp.MachineAuthorized = false resp.NodeKeyExpired = true - resp.User = *machine.User.toTailscaleUser() + resp.User = *machine.User.TailscaleUser() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { log.Error(). @@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon( return } - if machine.isEphemeral() { + if machine.IsEphemeral() { err = h.db.HardDeleteMachine(&machine) if err != nil { log.Error(). @@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleMachineValidRegistrationCommon( writer http.ResponseWriter, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon( resp.AuthURL = "" resp.MachineAuthorized = true - resp.User = *machine.User.toTailscaleUser() - resp.Login = *machine.User.toTailscaleLogin() + resp.User = *machine.User.TailscaleUser() + resp.Login = *machine.User.TailscaleLogin() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { @@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleMachineRefreshKeyCommon( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( Bool("noise", isNoise). Str("machine", machine.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) - if err := h.db.db.Save(&machine).Error; err != nil { + err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey) + if err != nil { log.Error(). Caller(). Err(err). @@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( } resp.AuthURL = "" - resp.User = *machine.User.toTailscaleUser() + resp.User = *machine.User.TailscaleUser() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { log.Error(). @@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { diff --git a/hscontrol/protocol_common_poll.go b/hscontrol/protocol_common_poll.go index 502c633..3d43238 100644 --- a/hscontrol/protocol_common_poll.go +++ b/hscontrol/protocol_common_poll.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" @@ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName") func (h *Headscale) handlePollCommon( writer http.ResponseWriter, ctx context.Context, - machine *Machine, + machine *types.Machine, mapRequest tailcfg.MapRequest, isNoise bool, ) { machine.Hostname = mapRequest.Hostinfo.Hostname - machine.HostInfo = HostInfo(*mapRequest.Hostinfo) + machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) now := time.Now().UTC() - err := h.db.processMachineRoutes(machine) + err := h.db.ProcessMachineRoutes(machine) if err != nil { log.Error(). Caller(). @@ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon( } // update ACLRules with peer informations (to update server tags if necessary) - if h.aclPolicy != nil { - err := h.UpdateACLRules() - if err != nil { - log.Error(). - Caller(). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Err(err) - } + if h.ACLPolicy != nil { + // TODO(kradalby): Since this is not blocking, I might have introduced a bug here. + // It will be resolved later as we change up the policy stuff. + h.policyUpdateChan <- struct{}{} // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) + err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) if err != nil { log.Error(). Caller(). @@ -78,19 +74,17 @@ func (h *Headscale) handlePollCommon( machine.LastSeen = &now } - if err := h.db.db.Updates(machine).Error; err != nil { - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("node_key", machine.NodeKey). - Str("machine", machine.Hostname). - Err(err). - Msg("Failed to persist/update machine in the database") - http.Error(writer, "", http.StatusInternalServerError) + if err := h.db.MachineSave(machine); err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Bool("noise", isNoise). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + Err(err). + Msg("Failed to persist/update machine in the database") + http.Error(writer, "", http.StatusInternalServerError) - return - } + return } mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) @@ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon( func (h *Headscale) pollNetMapStream( writer http.ResponseWriter, ctxReq context.Context, - machine *Machine, + machine *types.Machine, mapRequest tailcfg.MapRequest, pollDataChan chan []byte, keepAliveChan chan []byte, @@ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream( updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). Inc() - if h.db.isOutdated(machine, h.getLastStateChange()) { + if h.db.IsOutdated(machine, h.getLastStateChange()) { var lastUpdate time.Time if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker( updateChan chan struct{}, keepAliveChan chan []byte, mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) { keepAliveTicker := time.NewTicker(keepAliveInterval) diff --git a/hscontrol/protocol_common_utils.go b/hscontrol/protocol_common_utils.go index 1dababa..8990eeb 100644 --- a/hscontrol/protocol_common_utils.go +++ b/hscontrol/protocol_common_utils.go @@ -5,6 +5,7 @@ import ( "encoding/json" "sync" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" @@ -15,7 +16,7 @@ import ( func (h *Headscale) getMapResponseData( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) ([]byte, error) { mapResponse, err := h.generateMapResponse(mapRequest, machine) @@ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData( func (h *Headscale) getMapKeepAliveResponseData( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) ([]byte, error) { keepAliveResponse := tailcfg.MapResponse{ diff --git a/hscontrol/app_test.go b/hscontrol/suite_test.go similarity index 54% rename from hscontrol/app_test.go rename to hscontrol/suite_test.go index 1b4e91e..69a651a 100644 --- a/hscontrol/app_test.go +++ b/hscontrol/suite_test.go @@ -18,7 +18,7 @@ type Suite struct{} var ( tmpDir string - app Headscale + app *Headscale ) func (s *Suite) SetUpTest(c *check.C) { @@ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) { os.RemoveAll(tmpDir) } var err error - tmpDir, err = os.MkdirTemp("", "autoygg-client-test") + tmpDir, err = os.MkdirTemp("", "autoygg-client-test2") if err != nil { c.Fatal(err) } cfg := Config{ + PrivateKeyPath: tmpDir + "/private.key", + NoisePrivateKeyPath: tmpDir + "/noise_private.key", + DBtype: "sqlite3", + DBpath: tmpDir + "/headscale_test.db", IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, @@ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) { }, } - // TODO(kradalby): make this use NewHeadscale properly so it doesnt drift - app = Headscale{ - cfg: &cfg, - dbType: "sqlite3", - dbString: tmpDir + "/headscale_test.db", - - stateUpdateChan: make(chan struct{}), - cancelStateUpdateChan: make(chan struct{}), - } - - go app.watchStateChannel() - - db, err := NewHeadscaleDatabase( - app.dbType, - app.dbString, - cfg.OIDC.StripEmaildomain, - false, - app.stateUpdateChan, - cfg.IPPrefixes, - "", - ) + app, err = NewHeadscale(&cfg) if err != nil { c.Fatal(err) } - app.db = db } diff --git a/hscontrol/types/api_key.go b/hscontrol/types/api_key.go new file mode 100644 index 0000000..8ca0004 --- /dev/null +++ b/hscontrol/types/api_key.go @@ -0,0 +1,41 @@ +package types + +import ( + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// APIKey describes the datamodel for API keys used to remotely authenticate with +// headscale. +type APIKey struct { + ID uint64 `gorm:"primary_key"` + Prefix string `gorm:"uniqueIndex"` + Hash []byte + + CreatedAt *time.Time + Expiration *time.Time + LastSeen *time.Time +} + +func (key *APIKey) Proto() *v1.ApiKey { + protoKey := v1.ApiKey{ + Id: key.ID, + Prefix: key.Prefix, + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + if key.LastSeen != nil { + protoKey.LastSeen = timestamppb.New(*key.LastSeen) + } + + return &protoKey +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go new file mode 100644 index 0000000..96ad1b7 --- /dev/null +++ b/hscontrol/types/common.go @@ -0,0 +1,108 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "net/netip" + + "tailscale.com/tailcfg" +) + +var ErrCannotParsePrefix = errors.New("cannot parse prefix") + +// This is a "wrapper" type around tailscales +// Hostinfo to allow us to add database "serialization" +// methods. This allows us to use a typed values throughout +// the code and not have to marshal/unmarshal and error +// check all over the code. +type HostInfo tailcfg.Hostinfo + +func (hi *HostInfo) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, hi) + + case string: + return json.Unmarshal([]byte(value), hi) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (hi HostInfo) Value() (driver.Value, error) { + bytes, err := json.Marshal(hi) + + return string(bytes), err +} + +type IPPrefix netip.Prefix + +func (i *IPPrefix) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + prefix, err := netip.ParsePrefix(value) + if err != nil { + return err + } + *i = IPPrefix(prefix) + + return nil + default: + return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefix) Value() (driver.Value, error) { + prefixStr := netip.Prefix(i).String() + + return prefixStr, nil +} + +type IPPrefixes []netip.Prefix + +func (i *IPPrefixes) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, i) + + case string: + return json.Unmarshal([]byte(value), i) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefixes) Value() (driver.Value, error) { + bytes, err := json.Marshal(i) + + return string(bytes), err +} + +type StringList []string + +func (i *StringList) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, i) + + case string: + return json.Unmarshal([]byte(value), i) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i StringList) Value() (driver.Value, error) { + bytes, err := json.Marshal(i) + + return string(bytes), err +} diff --git a/hscontrol/types/machine.go b/hscontrol/types/machine.go new file mode 100644 index 0000000..a4ca03e --- /dev/null +++ b/hscontrol/types/machine.go @@ -0,0 +1,254 @@ +package types + +import ( + "database/sql/driver" + "errors" + "fmt" + "net/netip" + "strings" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "go4.org/netipx" + "google.golang.org/protobuf/types/known/timestamppb" + "tailscale.com/tailcfg" +) + +const ( + // TODO(kradalby): Move out of here when we got circdeps under control. + keepAliveInterval = 60 * time.Second +) + +var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") + +// Machine is a Headscale client. +type Machine struct { + ID uint64 `gorm:"primary_key"` + MachineKey string `gorm:"type:varchar(64);unique_index"` + NodeKey string + DiscoKey string + IPAddresses MachineAddresses + + // Hostname represents the name given by the Tailscale + // client during registration + Hostname string + + // Givenname represents either: + // a DNS normalized version of Hostname + // a valid name set by the User + // + // GivenName is the name used in all DNS related + // parts of headscale. + GivenName string `gorm:"type:varchar(63);unique_index"` + UserID uint + User User `gorm:"foreignKey:UserID"` + + RegisterMethod string + + ForcedTags StringList + + // TODO(kradalby): This seems like irrelevant information? + AuthKeyID uint + AuthKey *PreAuthKey + + LastSeen *time.Time + LastSuccessfulUpdate *time.Time + Expiry *time.Time + + HostInfo HostInfo + Endpoints StringList + + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type ( + Machines []Machine + MachinesP []*Machine +) + +type MachineAddresses []netip.Addr + +func (ma MachineAddresses) ToStringSlice() []string { + strSlice := make([]string, 0, len(ma)) + for _, addr := range ma { + strSlice = append(strSlice, addr.String()) + } + + return strSlice +} + +// AppendToIPSet adds the individual ips in MachineAddresses to a +// given netipx.IPSetBuilder. +func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { + for _, ip := range ma { + build.Add(ip) + } +} + +func (ma *MachineAddresses) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + addresses := strings.Split(value, ",") + *ma = (*ma)[:0] + for _, addr := range addresses { + if len(addr) < 1 { + continue + } + parsed, err := netip.ParseAddr(addr) + if err != nil { + return err + } + *ma = append(*ma, parsed) + } + + return nil + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (ma MachineAddresses) Value() (driver.Value, error) { + addresses := strings.Join(ma.ToStringSlice(), ",") + + return addresses, nil +} + +// IsExpired returns whether the machine registration has expired. +func (machine Machine) IsExpired() bool { + // If Expiry is not set, the client has not indicated that + // it wants an expiry time, it is therefor considered + // to mean "not expired" + if machine.Expiry == nil || machine.Expiry.IsZero() { + return false + } + + return time.Now().UTC().After(*machine.Expiry) +} + +// IsOnline returns if the machine is connected to Headscale. +// This is really a naive implementation, as we don't really see +// if there is a working connection between the client and the server. +func (machine *Machine) IsOnline() bool { + if machine.LastSeen == nil { + return false + } + + if machine.IsExpired() { + return false + } + + return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) +} + +// IsEphemeral returns if the machine is registered as an Ephemeral node. +// https://tailscale.com/kb/1111/ephemeral-nodes/ +func (machine *Machine) IsEphemeral() bool { + return machine.AuthKey != nil && machine.AuthKey.Ephemeral +} + +func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { + for _, rule := range filter { + // TODO(kradalby): Cache or pregen this + matcher := matcher.MatchFromFilterRule(rule) + + if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { + continue + } + + if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { + return true + } + } + + return false +} + +func (machines Machines) FilterByIP(ip netip.Addr) Machines { + found := make(Machines, 0) + + for _, machine := range machines { + for _, mIP := range machine.IPAddresses { + if ip == mIP { + found = append(found, machine) + } + } + } + + return found +} + +func (machine *Machine) Proto() *v1.Machine { + machineProto := &v1.Machine{ + Id: machine.ID, + MachineKey: machine.MachineKey, + + NodeKey: machine.NodeKey, + DiscoKey: machine.DiscoKey, + IpAddresses: machine.IPAddresses.ToStringSlice(), + Name: machine.Hostname, + GivenName: machine.GivenName, + User: machine.User.Proto(), + ForcedTags: machine.ForcedTags, + Online: machine.IsOnline(), + + // TODO(kradalby): Implement register method enum converter + // RegisterMethod: , + + CreatedAt: timestamppb.New(machine.CreatedAt), + } + + if machine.AuthKey != nil { + machineProto.PreAuthKey = machine.AuthKey.Proto() + } + + if machine.LastSeen != nil { + machineProto.LastSeen = timestamppb.New(*machine.LastSeen) + } + + if machine.LastSuccessfulUpdate != nil { + machineProto.LastSuccessfulUpdate = timestamppb.New( + *machine.LastSuccessfulUpdate, + ) + } + + if machine.Expiry != nil { + machineProto.Expiry = timestamppb.New(*machine.Expiry) + } + + return machineProto +} + +// GetHostInfo returns a Hostinfo struct for the machine. +func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { + return tailcfg.Hostinfo(machine.HostInfo) +} + +func (machine Machine) String() string { + return machine.Hostname +} + +func (machines Machines) String() string { + temp := make([]string, len(machines)) + + for index, machine := range machines { + temp[index] = machine.Hostname + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +// TODO(kradalby): Remove when we have generics... +func (machines MachinesP) String() string { + temp := make([]string, len(machines)) + + for index, machine := range machines { + temp[index] = machine.Hostname + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} diff --git a/hscontrol/types/machine_test.go b/hscontrol/types/machine_test.go new file mode 100644 index 0000000..ab1254f --- /dev/null +++ b/hscontrol/types/machine_test.go @@ -0,0 +1 @@ +package types diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go new file mode 100644 index 0000000..0d8c9cf --- /dev/null +++ b/hscontrol/types/preauth_key.go @@ -0,0 +1,58 @@ +package types + +import ( + "strconv" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// PreAuthKey describes a pre-authorization key usable in a particular user. +type PreAuthKey struct { + ID uint64 `gorm:"primary_key"` + Key string + UserID uint + User User + Reusable bool + Ephemeral bool `gorm:"default:false"` + Used bool `gorm:"default:false"` + ACLTags []PreAuthKeyACLTag + + CreatedAt *time.Time + Expiration *time.Time +} + +// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. +type PreAuthKeyACLTag struct { + ID uint64 `gorm:"primary_key"` + PreAuthKeyID uint64 + Tag string +} + +func (key *PreAuthKey) Proto() *v1.PreAuthKey { + protoKey := v1.PreAuthKey{ + User: key.User.Name, + Id: strconv.FormatUint(key.ID, util.Base10), + Key: key.Key, + Ephemeral: key.Ephemeral, + Reusable: key.Reusable, + Used: key.Used, + AclTags: make([]string, len(key.ACLTags)), + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + for idx := range key.ACLTags { + protoKey.AclTags[idx] = key.ACLTags[idx].Tag + } + + return &protoKey +} diff --git a/hscontrol/types/routes.go b/hscontrol/types/routes.go new file mode 100644 index 0000000..1f43071 --- /dev/null +++ b/hscontrol/types/routes.go @@ -0,0 +1,71 @@ +package types + +import ( + "fmt" + "net/netip" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" +) + +var ( + ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") + ExitRouteV6 = netip.MustParsePrefix("::/0") +) + +type Route struct { + gorm.Model + + MachineID uint64 + Machine Machine + Prefix IPPrefix + + Advertised bool + Enabled bool + IsPrimary bool +} + +type Routes []Route + +func (r *Route) String() string { + return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) +} + +func (r *Route) IsExitRoute() bool { + return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 +} + +func (rs Routes) Prefixes() []netip.Prefix { + prefixes := make([]netip.Prefix, len(rs)) + for i, r := range rs { + prefixes[i] = netip.Prefix(r.Prefix) + } + + return prefixes +} + +func (rs Routes) Proto() []*v1.Route { + protoRoutes := []*v1.Route{} + + for _, route := range rs { + protoRoute := v1.Route{ + Id: uint64(route.ID), + Machine: route.Machine.Proto(), + Prefix: netip.Prefix(route.Prefix).String(), + Advertised: route.Advertised, + Enabled: route.Enabled, + IsPrimary: route.IsPrimary, + CreatedAt: timestamppb.New(route.CreatedAt), + UpdatedAt: timestamppb.New(route.UpdatedAt), + } + + if route.DeletedAt.Valid { + protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) + } + + protoRoutes = append(protoRoutes, &protoRoute) + } + + return protoRoutes +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go new file mode 100644 index 0000000..d5e3c45 --- /dev/null +++ b/hscontrol/types/users.go @@ -0,0 +1,55 @@ +package types + +import ( + "strconv" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +// User is the way Headscale implements the concept of users in Tailscale +// +// At the end of the day, users in Tailscale are some kind of 'bubbles' or users +// that contain our machines. +type User struct { + gorm.Model + Name string `gorm:"unique"` +} + +func (n *User) TailscaleUser() *tailcfg.User { + user := tailcfg.User{ + ID: tailcfg.UserID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + Logins: []tailcfg.LoginID{}, + Created: time.Time{}, + } + + return &user +} + +func (n *User) TailscaleLogin() *tailcfg.Login { + login := tailcfg.Login{ + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + } + + return &login +} + +func (n *User) Proto() *v1.User { + return &v1.User{ + Id: strconv.FormatUint(uint64(n.ID), util.Base10), + Name: n.Name, + CreatedAt: timestamppb.New(n.CreatedAt), + } +} diff --git a/hscontrol/users_test.go b/hscontrol/users_test.go deleted file mode 100644 index 1d68f92..0000000 --- a/hscontrol/users_test.go +++ /dev/null @@ -1,415 +0,0 @@ -package hscontrol - -import ( - "net/netip" - "testing" - - "gopkg.in/check.v1" - "gorm.io/gorm" -) - -func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(user.Name, check.Equals, "test") - - users, err := app.db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = app.db.DestroyUser("test") - c.Assert(err, check.IsNil) - - _, err = app.db.GetUser("test") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := app.db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) - - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - err = app.db.DestroyUser("test") - c.Assert(err, check.IsNil) - - result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) - // destroying a user also deletes all associated preauthkeys - c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) - - user, err = app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - err = app.db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserStillHasNodes) -} - -func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(userTest.Name, check.Equals, "test") - - users, err := app.db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = app.db.RenameUser("test", "test-renamed") - c.Assert(err, check.IsNil) - - _, err = app.db.GetUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) - - _, err = app.db.GetUser("test-renamed") - c.Assert(err, check.IsNil) - - err = app.db.RenameUser("test-does-not-exit", "test") - c.Assert(err, check.Equals, ErrUserNotFound) - - userTest2, err := app.db.CreateUser("test2") - c.Assert(err, check.IsNil) - c.Assert(userTest2.Name, check.Equals, "test2") - - err = app.db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) -} - -func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - userShared1, err := app.db.CreateUser("shared1") - c.Assert(err, check.IsNil) - - userShared2, err := app.db.CreateUser("shared2") - c.Assert(err, check.IsNil) - - userShared3, err := app.db.CreateUser("shared3") - c.Assert(err, check.IsNil) - - preAuthKeyShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared2, err := app.db.CreatePreAuthKey( - userShared2.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared3, err := app.db.CreatePreAuthKey( - userShared3.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKey2Shared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") - c.Assert(err, check.NotNil) - - machineInShared1 := &Machine{ - ID: 1, - MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - Hostname: "test_get_shared_nodes_1", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - AuthKeyID: uint(preAuthKeyShared1.ID), - } - app.db.db.Save(machineInShared1) - - _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) - c.Assert(err, check.IsNil) - - machineInShared2 := &Machine{ - ID: 2, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_2", - UserID: userShared2.ID, - User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - AuthKeyID: uint(preAuthKeyShared2.ID), - } - app.db.db.Save(machineInShared2) - - _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) - c.Assert(err, check.IsNil) - - machineInShared3 := &Machine{ - ID: 3, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_3", - UserID: userShared3.ID, - User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - AuthKeyID: uint(preAuthKeyShared3.ID), - } - app.db.db.Save(machineInShared3) - - _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) - c.Assert(err, check.IsNil) - - machine2InShared1 := &Machine{ - ID: 4, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_4", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - AuthKeyID: uint(preAuthKey2Shared1.ID), - } - app.db.db.Save(machine2InShared1) - - peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) - c.Assert(err, check.IsNil) - - userProfiles := app.db.getMapResponseUserProfiles( - *machineInShared1, - peersOfMachine1InShared1, - ) - - c.Assert(len(userProfiles), check.Equals, 3) - - found := false - for _, userProfiles := range userProfiles { - if userProfiles.DisplayName == userShared1.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) - - found = false - for _, userProfile := range userProfiles { - if userProfile.DisplayName == userShared2.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) -} - -func TestNormalizeToFQDNRules(t *testing.T) { - type args struct { - name string - stripEmailDomain bool - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "normalize simple name", - args: args{ - name: "normalize-simple.name", - stripEmailDomain: false, - }, - want: "normalize-simple.name", - wantErr: false, - }, - { - name: "normalize an email", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: false, - }, - want: "foo.bar.example.com", - wantErr: false, - }, - { - name: "normalize an email domain should be removed", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: true, - }, - want: "foo.bar", - wantErr: false, - }, - { - name: "strip enabled no email passed as argument", - args: args{ - name: "not-email-and-strip-enabled", - stripEmailDomain: true, - }, - want: "not-email-and-strip-enabled", - wantErr: false, - }, - { - name: "normalize complex email", - args: args{ - name: "foo.bar+complex-email@example.com", - stripEmailDomain: false, - }, - want: "foo.bar-complex-email.example.com", - wantErr: false, - }, - { - name: "user name with space", - args: args{ - name: "name space", - stripEmailDomain: false, - }, - want: "name-space", - wantErr: false, - }, - { - name: "user with quote", - args: args{ - name: "Jamie's iPhone 5", - stripEmailDomain: false, - }, - want: "jamies-iphone-5", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) - if (err != nil) != tt.wantErr { - t.Errorf( - "NormalizeToFQDNRules() error = %v, wantErr %v", - err, - tt.wantErr, - ) - - return - } - if got != tt.want { - t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestCheckForFQDNRules(t *testing.T) { - type args struct { - name string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid: user", - args: args{name: "valid-user"}, - wantErr: false, - }, - { - name: "invalid: capitalized user", - args: args{name: "Invalid-CapItaLIzed-user"}, - wantErr: true, - }, - { - name: "invalid: email as user", - args: args{name: "foo.bar@example.com"}, - wantErr: true, - }, - { - name: "invalid: chars in user name", - args: args{name: "super-user+name"}, - wantErr: true, - }, - { - name: "invalid: too long name for user", - args: args{ - name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { - t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := app.db.CreateUser("old") - c.Assert(err, check.IsNil) - - newUser, err := app.db.CreateUser("new") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: oldUser.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - c.Assert(machine.UserID, check.Equals, oldUser.ID) - - err = app.db.SetMachineUser(&machine, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(machine.UserID, check.Equals, newUser.ID) - c.Assert(machine.User.Name, check.Equals, newUser.Name) - - err = app.db.SetMachineUser(&machine, "non-existing-user") - c.Assert(err, check.Equals, ErrUserNotFound) - - err = app.db.SetMachineUser(&machine, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(machine.UserID, check.Equals, newUser.ID) - c.Assert(machine.User.Name, check.Equals, newUser.Name) -} diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go index d312a6e..5c02c93 100644 --- a/hscontrol/util/addr.go +++ b/hscontrol/util/addr.go @@ -1,12 +1,94 @@ package util import ( + "fmt" "net/netip" "reflect" + "strings" "go4.org/netipx" ) +// This is borrowed from, and updated to use IPSet +// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 +// TODO(kradalby): contribute upstream and make public. +var ( + zeroIP4 = netip.AddrFrom4([4]byte{}) + zeroIP6 = netip.AddrFrom16([16]byte{}) +) + +// parseIPSet parses arg as one: +// +// - an IP address (IPv4 or IPv6) +// - the string "*" to match everything (both IPv4 & IPv6) +// - a CIDR (e.g. "192.168.0.0/16") +// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") +// +// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP +// address (without a slash) treated as a CIDR of *bits length. +// nolint +func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) { + var ipSet netipx.IPSetBuilder + if arg == "*" { + ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) + ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) + + return ipSet.IPSet() + } + if strings.Contains(arg, "/") { + pfx, err := netip.ParsePrefix(arg) + if err != nil { + return nil, err + } + if pfx != pfx.Masked() { + return nil, fmt.Errorf("%v contains non-network bits set", pfx) + } + + ipSet.AddPrefix(pfx) + + return ipSet.IPSet() + } + if strings.Count(arg, "-") == 1 { + ip1s, ip2s, _ := strings.Cut(arg, "-") + + ip1, err := netip.ParseAddr(ip1s) + if err != nil { + return nil, err + } + + ip2, err := netip.ParseAddr(ip2s) + if err != nil { + return nil, err + } + + r := netipx.IPRangeFrom(ip1, ip2) + if !r.IsValid() { + return nil, fmt.Errorf("invalid IP range %q", arg) + } + + for _, prefix := range r.Prefixes() { + ipSet.AddPrefix(prefix) + } + + return ipSet.IPSet() + } + ip, err := netip.ParseAddr(arg) + if err != nil { + return nil, fmt.Errorf("invalid IP address %q", arg) + } + bits8 := uint8(ip.BitLen()) + if bits != nil { + if *bits < 0 || *bits > int(bits8) { + return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) + } + bits8 = uint8(*bits) + } + + ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) + + return ipSet.IPSet() +} + func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { var network, broadcast netip.Addr ipRange := netipx.RangeOfPrefix(na) diff --git a/hscontrol/matcher_test.go b/hscontrol/util/addr_test.go similarity index 96% rename from hscontrol/matcher_test.go rename to hscontrol/util/addr_test.go index fb0e9b0..45b2b92 100644 --- a/hscontrol/matcher_test.go +++ b/hscontrol/util/addr_test.go @@ -1,4 +1,4 @@ -package hscontrol +package util import ( "net/netip" @@ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := parseIPSet(tt.args.arg, tt.args.bits) + got, err := ParseIPSet(tt.args.arg, tt.args.bits) if (err != nil) != tt.wantErr { t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/util/const.go b/hscontrol/util/const.go new file mode 100644 index 0000000..4f7c811 --- /dev/null +++ b/hscontrol/util/const.go @@ -0,0 +1,7 @@ +package util + +const ( + RegisterMethodAuthKey = "authkey" + RegisterMethodOIDC = "oidc" + RegisterMethodCLI = "cli" +) diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go new file mode 100644 index 0000000..72af8f8 --- /dev/null +++ b/hscontrol/util/dns.go @@ -0,0 +1,69 @@ +package util + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // value related to RFC 1123 and 952. + LabelHostnameLength = 63 +) + +var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") + +var ErrInvalidUserName = errors.New("invalid user name") + +// NormalizeToFQDNRules will replace forbidden chars in user +// it can also return an error if the user doesn't respect RFC 952 and 1123. +func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > LabelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + ErrInvalidUserName, + ) + } + } + + return name, nil +} + +func CheckForFQDNRules(name string) error { + if len(name) > LabelHostnameLength { + return fmt.Errorf( + "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", + name, + ErrInvalidUserName, + ) + } + if strings.ToLower(name) != name { + return fmt.Errorf( + "DNS segment should be lowercase. %v doesn't comply with this rule: %w", + name, + ErrInvalidUserName, + ) + } + if invalidCharsInUserRegex.MatchString(name) { + return fmt.Errorf( + "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", + name, + ErrInvalidUserName, + ) + } + + return nil +} diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go new file mode 100644 index 0000000..ab66a13 --- /dev/null +++ b/hscontrol/util/dns_test.go @@ -0,0 +1,143 @@ +package util + +import "testing" + +func TestNormalizeToFQDNRules(t *testing.T) { + type args struct { + name string + stripEmailDomain bool + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "normalize simple name", + args: args{ + name: "normalize-simple.name", + stripEmailDomain: false, + }, + want: "normalize-simple.name", + wantErr: false, + }, + { + name: "normalize an email", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: false, + }, + want: "foo.bar.example.com", + wantErr: false, + }, + { + name: "normalize an email domain should be removed", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: true, + }, + want: "foo.bar", + wantErr: false, + }, + { + name: "strip enabled no email passed as argument", + args: args{ + name: "not-email-and-strip-enabled", + stripEmailDomain: true, + }, + want: "not-email-and-strip-enabled", + wantErr: false, + }, + { + name: "normalize complex email", + args: args{ + name: "foo.bar+complex-email@example.com", + stripEmailDomain: false, + }, + want: "foo.bar-complex-email.example.com", + wantErr: false, + }, + { + name: "user name with space", + args: args{ + name: "name space", + stripEmailDomain: false, + }, + want: "name-space", + wantErr: false, + }, + { + name: "user with quote", + args: args{ + name: "Jamie's iPhone 5", + stripEmailDomain: false, + }, + want: "jamies-iphone-5", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) + if (err != nil) != tt.wantErr { + t.Errorf( + "NormalizeToFQDNRules() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + if got != tt.want { + t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckForFQDNRules(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid: user", + args: args{name: "valid-user"}, + wantErr: false, + }, + { + name: "invalid: capitalized user", + args: args{name: "Invalid-CapItaLIzed-user"}, + wantErr: true, + }, + { + name: "invalid: email as user", + args: args{name: "foo.bar@example.com"}, + wantErr: true, + }, + { + name: "invalid: chars in user name", + args: args{name: "super-user+name"}, + wantErr: true, + }, + { + name: "invalid: too long name for user", + args: args{ + name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { + t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/integration/acl_test.go b/integration/acl_test.go index e85e28c..ca184b8 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -45,7 +45,7 @@ var veryLargeDestination = []string{ "208.0.0.0/4:*", } -func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario { +func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { t.Helper() scenario, err := NewScenario() assert.NoError(t, err) @@ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { // they can access minus one (them self). tests := map[string]struct { users map[string]int - policy hscontrol.ACLPolicy + policy policy.ACLPolicy want map[string]int }{ // Test that when we have no ACL, each client netmap has @@ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, @@ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-acl-test": {"user1", "user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"group:integration-acl-test"}, @@ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + &policy.ACLPolicy{ + Hosts: policy.Hosts{ "all": netip.MustParsePrefix("100.64.0.0/24"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy hscontrol.ACLPolicy + policy policy.ACLPolicy }{ "ipv4": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("100.64.0.1/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"), "test3": netip.MustParsePrefix("100.64.0.3/32"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) { }, }, "ipv6": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy hscontrol.ACLPolicy + policy policy.ACLPolicy }{ "ipv4": { - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"100.64.0.1"}, @@ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "ipv6": { - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"fd7a:115c:a1e0::1"}, @@ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "hostv4cidr": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("100.64.0.1/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"test1"}, @@ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "hostv6cidr": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"test1"}, @@ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "group": { - policy: hscontrol.ACLPolicy{ + policy: policy.ACLPolicy{ Groups: map[string][]string{ "group:one": {"user1"}, "group:two": {"user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"group:one"}, diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 0051b40..d27eb06 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -23,7 +23,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" @@ -60,7 +60,7 @@ type HeadscaleInContainer struct { port int extraPorts []string hostPortBindings map[string][]string - aclPolicy *hscontrol.ACLPolicy + aclPolicy *policy.ACLPolicy env map[string]string tlsCert []byte tlsKey []byte @@ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer) // WithACLPolicy adds a hscontrol.ACLPolicy policy to the // HeadscaleInContainer instance. -func WithACLPolicy(acl *hscontrol.ACLPolicy) Option { +func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 922ced6..006ac0c 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1", "user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{}, + SSHs: []policy.SSH{}, }, ), hsic.WithTestName("sshnoneconfigured"), @@ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:80"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:ssh1": {"useracl1"}, "group:ssh2": {"useracl2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:ssh1"},