diff --git a/.golangci.yaml b/.golangci.yaml index 153cd7c..56fdc28 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -29,6 +29,7 @@ linters: - wrapcheck - dupl - makezero + - maintidx # We might want to enable this, but it might be a lot of work - cyclop diff --git a/CHANGELOG.md b/CHANGELOG.md index eaf7ade..d72b873 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,33 +1,38 @@ # CHANGELOG -**0.15.0 (2022-xx-xx):** +## 0.15.0 (2022-xx-xx) -**BREAKING**: +**Note:** Take a backup of your database before upgrading. + +### BREAKING - Boundaries between Namespaces has been removed and all nodes can communicate by default [#357](https://github.com/juanfont/headscale/pull/357) - To limit access between nodes, use [ACLs](./docs/acls.md). -**Features**: +### Features - Add support for writing ACL files with YAML [#359](https://github.com/juanfont/headscale/pull/359) - Users can now use emails in ACL's groups [#372](https://github.com/juanfont/headscale/issues/372) - Add shorthand aliases for commands and subcommands [#376](https://github.com/juanfont/headscale/pull/376) -**Changes**: +### Changes - Fix a bug were the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346) +- Simplify the code behind registration of machines [#366](https://github.com/juanfont/headscale/pull/366) + - Nodes are now only written to database if they are registrated successfully +- Fix a limitation in the ACLs that prevented users to write rules with `*` as source [#374](https://github.com/juanfont/headscale/issues/374) -**0.14.0 (2022-02-24):** +## 0.14.0 (2022-02-24) -**UPCOMING BREAKING**: -From the **next** version (`0.15.0`), all machines will be able to communicate regardless of +**UPCOMING ### BREAKING +From the **next\*\* version (`0.15.0`), all machines will be able to communicate regardless of if they are in the same namespace. This means that the behaviour currently limited to ACLs will become default. From version `0.15.0`, all limitation of communications must be done with ACLs. This is a part of aligning `headscale`'s behaviour with Tailscale's upstream behaviour. -**BREAKING**: +### BREAKING - ACLs have been rewritten to align with the bevaviour Tailscale Control Panel provides. **NOTE:** This is only active if you use ACLs - Namespaces are now treated as Users @@ -35,17 +40,17 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh - Tags should now work correctly and adding a host to Headscale should now reload the rules. - The documentation have a [fictional example](docs/acls.md) that should cover some use cases of the ACLs features -**Features**: +### Features - Add support for configurable mTLS [docs](docs/tls.md#configuring-mutual-tls-authentication-mtls) [#297](https://github.com/juanfont/headscale/pull/297) -**Changes**: +### Changes - Remove dependency on CGO (switch from CGO SQLite to pure Go) [#346](https://github.com/juanfont/headscale/pull/346) **0.13.0 (2022-02-18):** -**Features**: +### Features - Add IPv6 support to the prefix assigned to namespaces - Add API Key support @@ -56,7 +61,7 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh - `oidc.domain_map` option has been removed - `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml)) -**Changes**: +### Changes - `ip_prefix` is now superseded by `ip_prefixes` in the configuration [#208](https://github.com/juanfont/headscale/pull/208) - Upgrade `tailscale` (1.20.4) and other dependencies to latest [#314](https://github.com/juanfont/headscale/pull/314) @@ -65,35 +70,35 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh **0.12.4 (2022-01-29):** -**Changes**: +### Changes - Make gRPC Unix Socket permissions configurable [#292](https://github.com/juanfont/headscale/pull/292) - Trim whitespace before reading Private Key from file [#289](https://github.com/juanfont/headscale/pull/289) - Add new command to generate a private key for `headscale` [#290](https://github.com/juanfont/headscale/pull/290) - Fixed issue where hosts deleted from control server may be written back to the database, as long as they are connected to the control server [#278](https://github.com/juanfont/headscale/pull/278) -**0.12.3 (2022-01-13):** +## 0.12.3 (2022-01-13) -**Changes**: +### Changes - Added Alpine container [#270](https://github.com/juanfont/headscale/pull/270) - Minor updates in dependencies [#271](https://github.com/juanfont/headscale/pull/271) -**0.12.2 (2022-01-11):** +## 0.12.2 (2022-01-11) Happy New Year! -**Changes**: +### Changes - Fix Docker release [#258](https://github.com/juanfont/headscale/pull/258) - Rewrite main docs [#262](https://github.com/juanfont/headscale/pull/262) - Improve Docker docs [#263](https://github.com/juanfont/headscale/pull/263) -**0.12.1 (2021-12-24):** +## 0.12.1 (2021-12-24) (We are skipping 0.12.0 to correct a mishap done weeks ago with the version tagging) -**BREAKING**: +### BREAKING - Upgrade to Tailscale 1.18 [#229](https://github.com/juanfont/headscale/pull/229) - This change requires a new format for private key, private keys are now generated automatically: @@ -101,19 +106,19 @@ Happy New Year! 2. Restart `headscale`, a new key will be generated. 3. Restart all Tailscale clients to fetch the new key -**Changes**: +### Changes - Unify configuration example [#197](https://github.com/juanfont/headscale/pull/197) - Add stricter linting and formatting [#223](https://github.com/juanfont/headscale/pull/223) -**Features**: +### Features - Add gRPC and HTTP API (HTTP API is currently disabled) [#204](https://github.com/juanfont/headscale/pull/204) - Use gRPC between the CLI and the server [#206](https://github.com/juanfont/headscale/pull/206), [#212](https://github.com/juanfont/headscale/pull/212) - Beta OpenID Connect support [#126](https://github.com/juanfont/headscale/pull/126), [#227](https://github.com/juanfont/headscale/pull/227) -**0.11.0 (2021-10-25):** +## 0.11.0 (2021-10-25) -**BREAKING**: +### BREAKING - Make headscale fetch DERP map from URL and file [#196](https://github.com/juanfont/headscale/pull/196) diff --git a/acls_test.go b/acls_test.go index 6cee02c..16861a2 100644 --- a/acls_test.go +++ b/acls_test.go @@ -119,7 +119,6 @@ func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) { Name: "testmachine", IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostInfo), @@ -163,7 +162,6 @@ func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) { Name: "testmachine", IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostInfo), @@ -207,7 +205,6 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) { Name: "testmachine", IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostInfo), @@ -250,7 +247,6 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { Name: "webserver", IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostInfo), @@ -267,7 +263,6 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { Name: "user", IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostInfo), @@ -361,7 +356,6 @@ func (s *Suite) TestPortNamespace(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: ips, AuthKeyID: uint(pak.ID), @@ -404,7 +398,6 @@ func (s *Suite) TestPortGroup(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: ips, AuthKeyID: uint(pak.ID), diff --git a/api.go b/api.go index bb5495a..3b3b675 100644 --- a/api.go +++ b/api.go @@ -22,7 +22,7 @@ import ( const ( reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authKey" + RegisterMethodAuthKey = "authkey" RegisterMethodOIDC = "oidc" RegisterMethodCLI = "cli" ErrRegisterMethodCLIDoesNotSupportExpire = Error( @@ -125,25 +125,50 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") - newMachine := Machine{ - Expiry: &time.Time{}, - MachineKey: MachinePublicKeyStripPrefix(machineKey), - Name: req.Hostinfo.Hostname, - } - if err := h.db.Create(&newMachine).Error; err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Could not create row") - machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). - Inc() + + machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + + // If the machine has AuthKey set, handle registration via PreAuthKeys + if req.Auth.AuthKey != "" { + h.handleAuthKey(ctx, machineKey, req) return } - machine = &newMachine + + // The machine did not have a key to authenticate, which means + // that we rely on a method that calls back some how (OpenID or CLI) + // We create the machine and then keep it around until a callback + // happens + newMachine := Machine{ + MachineKey: machineKeyStr, + Name: req.Hostinfo.Hostname, + NodeKey: NodePublicKeyStripPrefix(req.NodeKey), + LastSeen: &now, + Expiry: &time.Time{}, + } + + if !req.Expiry.IsZero() { + log.Trace(). + Caller(). + Str("machine", req.Hostinfo.Hostname). + Time("expiry", req.Expiry). + Msg("Non-zero expiry time requested") + newMachine.Expiry = &req.Expiry + } + + h.registrationCache.Set( + machineKeyStr, + newMachine, + registerCacheExpiration, + ) + + h.handleMachineRegistrationNew(ctx, machineKey, req) + + return } - if machine.Registered { + // The machine is already registered, so we need to pass through reauth or key update. + if machine != nil { // If the NodeKey stored in headscale is the same as the key presented in a registration // request, then we have a node that is either: // - Trying to log out (sending a expiry in the past) @@ -180,15 +205,6 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { return } - - // If the machine has AuthKey set, handle registration via PreAuthKeys - if req.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, req, *machine) - - return - } - - h.handleMachineRegistrationNew(ctx, machineKey, req, *machine) } func (h *Headscale) getMapResponse( @@ -402,7 +418,7 @@ func (h *Headscale) handleMachineExpired( Msg("Machine registration has expired. Sending a authurl to register") if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, registerRequest, machine) + h.handleAuthKey(ctx, machineKey, registerRequest) return } @@ -465,13 +481,12 @@ func (h *Headscale) handleMachineRegistrationNew( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, - machine Machine, ) { resp := tailcfg.RegisterResponse{} // The machine registration is new, redirect the client to the registration URL log.Debug(). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( @@ -484,24 +499,6 @@ func (h *Headscale) handleMachineRegistrationNew( strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) } - if !registerRequest.Expiry.IsZero() { - log.Trace(). - Caller(). - Str("machine", machine.Name). - Time("expiry", registerRequest.Expiry). - Msg("Non-zero expiry time requested, adding to cache") - h.requestedExpiryCache.Set( - machineKey.String(), - registerRequest.Expiry, - requestedExpiryCacheExpiration, - ) - } - - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) - - // save the NodeKey - h.db.Save(&machine) - respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). @@ -520,19 +517,21 @@ func (h *Headscale) handleAuthKey( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, - machine Machine, ) { + machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + log.Debug(). Str("func", "handleAuthKey"). Str("machine", registerRequest.Hostinfo.Hostname). Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} + pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false @@ -541,76 +540,66 @@ func (h *Headscale) handleAuthKey( log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") - machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() return } + ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Msg("Failed authentication via AuthKey") - machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() return } - if machine.isRegistered() { - log.Trace(). - Caller(). - Str("machine", machine.Name). - Msg("machine already registered, reauthenticating") + log.Debug(). + Str("func", "handleAuthKey"). + Str("machine", registerRequest.Hostinfo.Hostname). + Msg("Authentication key was valid, proceeding to acquire IP addresses") - h.RefreshMachine(&machine, registerRequest.Expiry) - } else { - log.Debug(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Msg("Authentication key was valid, proceeding to acquire IP addresses") + nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) + now := time.Now().UTC() - h.ipAllocationMutex.Lock() - - ips, err := h.getAvailableIPs() - if err != nil { - log.Error(). - Caller(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Msg("Failed to find an available IP address") - machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). - Inc() - - return - } - log.Info(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Str("ips", strings.Join(ips.ToStringSlice(), ",")). - Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name) - - machine.Expiry = ®isterRequest.Expiry - machine.AuthKeyID = uint(pak.ID) - machine.IPAddresses = ips - machine.NamespaceID = pak.NamespaceID - - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) - // we update it just in case - machine.Registered = true - machine.RegisterMethod = RegisterMethodAuthKey - h.db.Save(&machine) - - h.ipAllocationMutex.Unlock() + machineToRegister := Machine{ + Name: registerRequest.Hostinfo.Hostname, + NamespaceID: pak.Namespace.ID, + MachineKey: machineKeyStr, + RegisterMethod: RegisterMethodAuthKey, + Expiry: ®isterRequest.Expiry, + NodeKey: nodeKey, + LastSeen: &now, + AuthKeyID: uint(pak.ID), } - pak.Used = true - h.db.Save(&pak) + machine, err := h.RegisterMachine( + machineToRegister, + ) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not register machine") + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). + Inc() + ctx.String( + http.StatusInternalServerError, + "could not register machine", + ) + + return + } + + h.UsePreAuthKey(pak) resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() @@ -619,21 +608,21 @@ func (h *Headscale) handleAuthKey( log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "Extremely sad!") return } - machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) log.Info(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Msg("Successfully authenticated via AuthKey") } diff --git a/app.go b/app.go index eda670e..5818509 100644 --- a/app.go +++ b/app.go @@ -55,8 +55,8 @@ const ( HTTPReadTimeout = 30 * time.Second privateKeyFileMode = 0o600 - requestedExpiryCacheExpiration = time.Minute * 5 - requestedExpiryCacheCleanupInterval = time.Minute * 10 + registerCacheExpiration = time.Minute * 15 + registerCacheCleanup = time.Minute * 20 errUnsupportedDatabase = Error("unsupported DB") errUnsupportedLetsEncryptChallengeType = Error( @@ -148,11 +148,10 @@ type Headscale struct { lastStateChange sync.Map - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - oidcStateCache *cache.Cache + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config - requestedExpiryCache *cache.Cache + registrationCache *cache.Cache ipAllocationMutex sync.Mutex } @@ -202,18 +201,18 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, errUnsupportedDatabase } - requestedExpiryCache := cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, + registrationCache := cache.New( + registerCacheExpiration, + registerCacheCleanup, ) app := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - aclRules: tailcfg.FilterAllowAll, // default allowall - requestedExpiryCache: requestedExpiryCache, + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + aclRules: tailcfg.FilterAllowAll, // default allowall + registrationCache: registrationCache, } err = app.initDB() diff --git a/app_test.go b/app_test.go index 53c703a..96036a1 100644 --- a/app_test.go +++ b/app_test.go @@ -5,7 +5,6 @@ import ( "os" "testing" - "github.com/patrickmn/go-cache" "gopkg.in/check.v1" "inet.af/netaddr" ) @@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) { cfg: cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", - requestedExpiryCache: cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, - ), } err = app.initDB() if err != nil { diff --git a/cli_test.go b/cli_test.go deleted file mode 100644 index 71f2bea..0000000 --- a/cli_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package headscale - -import ( - "time" - - "gopkg.in/check.v1" - "inet.af/netaddr" -) - -func (s *Suite) TestRegisterMachine(c *check.C) { - namespace, err := app.CreateNamespace("test") - c.Assert(err, check.IsNil) - - now := time.Now().UTC() - - machine := Machine{ - ID: 0, - MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", - NodeKey: "bar", - DiscoKey: "faa", - Name: "testmachine", - NamespaceID: namespace.ID, - IPAddresses: []netaddr.IP{netaddr.MustParseIP("10.0.0.1")}, - Expiry: &now, - } - err = app.db.Save(&machine).Error - c.Assert(err, check.IsNil) - - _, err = app.GetMachine(namespace.Name, machine.Name) - c.Assert(err, check.IsNil) - - machineAfterRegistering, err := app.RegisterMachine( - machine.MachineKey, - namespace.Name, - ) - c.Assert(err, check.IsNil) - c.Assert(machineAfterRegistering.Registered, check.Equals, true) - - _, err = machineAfterRegistering.GetHostInfo() - c.Assert(err, check.IsNil) -} diff --git a/db.go b/db.go index 8db7bc6..e87cb6d 100644 --- a/db.go +++ b/db.go @@ -5,6 +5,7 @@ import ( "time" "github.com/glebarez/sqlite" + "github.com/rs/zerolog/log" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -34,6 +35,38 @@ func (h *Headscale) initDB() error { _ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") + // If the Machine table has a column for registered, + // find all occourences of "false" and drop them. Then + // remove the column. + if db.Migrator().HasColumn(&Machine{}, "registered") { + log.Info(). + Msg(`Database has legacy "registered" column in machine, removing...`) + + machines := Machines{} + if err := h.db.Not("registered").Find(&machines).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + } + + for _, machine := range machines { + log.Info(). + Str("machine", machine.Name). + Str("machine_key", machine.MachineKey). + Msg("Deleting unregistered machine") + if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil { + log.Error(). + Err(err). + Str("machine", machine.Name). + Str("machine_key", machine.MachineKey). + Msg("Error deleting unregistered machine") + } + } + + err := db.Migrator().DropColumn(&Machine{}, "registered") + if err != nil { + log.Error().Err(err).Msg("Error dropping registered column") + } + } + err = db.AutoMigrate(&Machine{}) if err != nil { return err diff --git a/dns_test.go b/dns_test.go index 23a34ca..5fd30d3 100644 --- a/dns_test.go +++ b/dns_test.go @@ -164,7 +164,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Name: "test_get_shared_nodes_1", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), @@ -182,7 +181,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Name: "test_get_shared_nodes_2", NamespaceID: namespaceShared2.ID, Namespace: *namespaceShared2, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), @@ -200,7 +198,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Name: "test_get_shared_nodes_3", NamespaceID: namespaceShared3.ID, Namespace: *namespaceShared3, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), @@ -218,7 +215,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Name: "test_get_shared_nodes_4", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), @@ -311,7 +307,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Name: "test_get_shared_nodes_1", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), @@ -329,7 +324,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Name: "test_get_shared_nodes_2", NamespaceID: namespaceShared2.ID, Namespace: *namespaceShared2, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), @@ -347,7 +341,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Name: "test_get_shared_nodes_3", NamespaceID: namespaceShared3.ID, Namespace: *namespaceShared3, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), @@ -365,7 +358,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Name: "test_get_shared_nodes_4", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), diff --git a/gen/go/headscale/v1/machine.pb.go b/gen/go/headscale/v1/machine.pb.go index 9a5064f..b25db29 100644 --- a/gen/go/headscale/v1/machine.pb.go +++ b/gen/go/headscale/v1/machine.pb.go @@ -85,13 +85,12 @@ type Machine struct { IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"` Name string `protobuf:"bytes,6,opt,name=name,proto3" json:"name,omitempty"` Namespace *Namespace `protobuf:"bytes,7,opt,name=namespace,proto3" json:"namespace,omitempty"` - Registered bool `protobuf:"varint,8,opt,name=registered,proto3" json:"registered,omitempty"` - RegisterMethod RegisterMethod `protobuf:"varint,9,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"` - LastSeen *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` - LastSuccessfulUpdate *timestamppb.Timestamp `protobuf:"bytes,11,opt,name=last_successful_update,json=lastSuccessfulUpdate,proto3" json:"last_successful_update,omitempty"` - Expiry *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=expiry,proto3" json:"expiry,omitempty"` - PreAuthKey *PreAuthKey `protobuf:"bytes,13,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,14,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + LastSeen *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` + LastSuccessfulUpdate *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=last_successful_update,json=lastSuccessfulUpdate,proto3" json:"last_successful_update,omitempty"` + Expiry *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=expiry,proto3" json:"expiry,omitempty"` + PreAuthKey *PreAuthKey `protobuf:"bytes,11,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + RegisterMethod RegisterMethod `protobuf:"varint,13,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"` } func (x *Machine) Reset() { @@ -175,20 +174,6 @@ func (x *Machine) GetNamespace() *Namespace { return nil } -func (x *Machine) GetRegistered() bool { - if x != nil { - return x.Registered - } - return false -} - -func (x *Machine) GetRegisterMethod() RegisterMethod { - if x != nil { - return x.RegisterMethod - } - return RegisterMethod_REGISTER_METHOD_UNSPECIFIED -} - func (x *Machine) GetLastSeen() *timestamppb.Timestamp { if x != nil { return x.LastSeen @@ -224,6 +209,13 @@ func (x *Machine) GetCreatedAt() *timestamppb.Timestamp { return nil } +func (x *Machine) GetRegisterMethod() RegisterMethod { + if x != nil { + return x.RegisterMethod + } + return RegisterMethod_REGISTER_METHOD_UNSPECIFIED +} + type RegisterMachineRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -822,7 +814,7 @@ var file_headscale_v1_machine_proto_rawDesc = []byte{ 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, - 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xfd, 0x04, 0x0a, 0x07, 0x4d, 0x61, 0x63, + 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x04, 0x0a, 0x07, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x61, 0x63, 0x68, 0x69, @@ -836,33 +828,31 @@ var file_headscale_v1_machine_proto_rawDesc = []byte{ 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x35, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, - 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x72, - 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x65, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x65, 0x64, 0x12, 0x45, 0x0a, 0x0f, 0x72, - 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x09, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, - 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, - 0x6f, 0x64, 0x52, 0x0e, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, - 0x6f, 0x64, 0x12, 0x37, 0x0a, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x65, 0x65, 0x6e, 0x12, 0x50, 0x0a, 0x16, 0x6c, - 0x61, 0x73, 0x74, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x5f, 0x75, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x14, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x75, 0x63, - 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x32, 0x0a, - 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, - 0x79, 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, - 0x79, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, - 0x79, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x39, 0x0a, - 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, 0x48, 0x0a, 0x16, 0x52, 0x65, 0x67, 0x69, + 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x09, 0x6c, + 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74, + 0x53, 0x65, 0x65, 0x6e, 0x12, 0x50, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x75, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x5f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, + 0x52, 0x14, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x32, 0x0a, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72, + 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, + 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x74, 0x12, 0x45, 0x0a, 0x0f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x6d, 0x65, + 0x74, 0x68, 0x6f, 0x64, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x68, 0x65, 0x61, + 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, + 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x52, 0x0e, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, + 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x22, 0x48, 0x0a, 0x16, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, @@ -962,12 +952,12 @@ var file_headscale_v1_machine_proto_goTypes = []interface{}{ } var file_headscale_v1_machine_proto_depIdxs = []int32{ 14, // 0: headscale.v1.Machine.namespace:type_name -> headscale.v1.Namespace - 0, // 1: headscale.v1.Machine.register_method:type_name -> headscale.v1.RegisterMethod - 15, // 2: headscale.v1.Machine.last_seen:type_name -> google.protobuf.Timestamp - 15, // 3: headscale.v1.Machine.last_successful_update:type_name -> google.protobuf.Timestamp - 15, // 4: headscale.v1.Machine.expiry:type_name -> google.protobuf.Timestamp - 16, // 5: headscale.v1.Machine.pre_auth_key:type_name -> headscale.v1.PreAuthKey - 15, // 6: headscale.v1.Machine.created_at:type_name -> google.protobuf.Timestamp + 15, // 1: headscale.v1.Machine.last_seen:type_name -> google.protobuf.Timestamp + 15, // 2: headscale.v1.Machine.last_successful_update:type_name -> google.protobuf.Timestamp + 15, // 3: headscale.v1.Machine.expiry:type_name -> google.protobuf.Timestamp + 16, // 4: headscale.v1.Machine.pre_auth_key:type_name -> headscale.v1.PreAuthKey + 15, // 5: headscale.v1.Machine.created_at:type_name -> google.protobuf.Timestamp + 0, // 6: headscale.v1.Machine.register_method:type_name -> headscale.v1.RegisterMethod 1, // 7: headscale.v1.RegisterMachineResponse.machine:type_name -> headscale.v1.Machine 1, // 8: headscale.v1.GetMachineResponse.machine:type_name -> headscale.v1.Machine 1, // 9: headscale.v1.ExpireMachineResponse.machine:type_name -> headscale.v1.Machine diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index db348f4..8d808f9 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -885,12 +885,6 @@ "namespace": { "$ref": "#/definitions/v1Namespace" }, - "registered": { - "type": "boolean" - }, - "registerMethod": { - "$ref": "#/definitions/v1RegisterMethod" - }, "lastSeen": { "type": "string", "format": "date-time" @@ -909,6 +903,9 @@ "createdAt": { "type": "string", "format": "date-time" + }, + "registerMethod": { + "$ref": "#/definitions/v1RegisterMethod" } } }, diff --git a/grpcv1.go b/grpcv1.go index 60e181d..1c02290 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -159,9 +159,11 @@ func (api headscaleV1APIServer) RegisterMachine( Str("namespace", request.GetNamespace()). Str("machine_key", request.GetKey()). Msg("Registering machine") - machine, err := api.h.RegisterMachine( + + machine, err := api.h.RegisterMachineFromAuthCallback( request.GetKey(), request.GetNamespace(), + RegisterMethodCLI, ) if err != nil { return nil, err @@ -398,11 +400,11 @@ func (api headscaleV1APIServer) DebugCreateMachine( HostInfo: datatypes.JSON(hostinfoJson), } - // log.Trace().Caller().Interface("machine", newMachine).Msg("") - - if err := api.h.db.Create(&newMachine).Error; err != nil { - return nil, err - } + api.h.registrationCache.Set( + request.GetKey(), + newMachine, + registerCacheExpiration, + ) return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil } diff --git a/integration_cli_test.go b/integration_cli_test.go index aae80cb..7ce0758 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -621,12 +621,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Equal(s.T(), "machine-4", listAll[3].Name) assert.Equal(s.T(), "machine-5", listAll[4].Name) - assert.True(s.T(), listAll[0].Registered) - assert.True(s.T(), listAll[1].Registered) - assert.True(s.T(), listAll[2].Registered) - assert.True(s.T(), listAll[3].Registered) - assert.True(s.T(), listAll[4].Registered) - otherNamespaceMachineKeys := []string{ "b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", "dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", @@ -710,9 +704,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Equal(s.T(), "otherNamespace-machine-1", listAllWithotherNamespace[5].Name) assert.Equal(s.T(), "otherNamespace-machine-2", listAllWithotherNamespace[6].Name) - assert.True(s.T(), listAllWithotherNamespace[5].Registered) - assert.True(s.T(), listAllWithotherNamespace[6].Registered) - // Test list all nodes after added otherNamespace listOnlyotherNamespaceMachineNamespaceResult, err := ExecuteCommand( &s.headscale, @@ -752,9 +743,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { listOnlyotherNamespaceMachineNamespace[1].Name, ) - assert.True(s.T(), listOnlyotherNamespaceMachineNamespace[0].Registered) - assert.True(s.T(), listOnlyotherNamespaceMachineNamespace[1].Registered) - // Delete a machines _, err = ExecuteCommand( &s.headscale, @@ -979,7 +967,6 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() { assert.Equal(s.T(), uint64(1), machine.Id) assert.Equal(s.T(), "route-machine", machine.Name) - assert.True(s.T(), machine.Registered) listAllResult, err := ExecuteCommand( &s.headscale, diff --git a/machine.go b/machine.go index 3106d0b..dfe099d 100644 --- a/machine.go +++ b/machine.go @@ -20,11 +20,14 @@ import ( ) const ( - errMachineNotFound = Error("machine not found") - errMachineAlreadyRegistered = Error("machine already registered") - errMachineRouteIsNotAvailable = Error("route is not available on machine") - errMachineAddressesInvalid = Error("failed to parse machine addresses") - errHostnameTooLong = Error("Hostname too long") + errMachineNotFound = Error("machine not found") + errMachineRouteIsNotAvailable = Error("route is not available on machine") + errMachineAddressesInvalid = Error("failed to parse machine addresses") + errMachineNotFoundRegistrationCache = Error( + "machine not found in registration cache", + ) + errCouldNotConvertMachineInterface = Error("failed to convert machine interface") + errHostnameTooLong = Error("Hostname too long") ) const ( @@ -42,15 +45,18 @@ type Machine struct { NamespaceID uint Namespace Namespace `gorm:"foreignKey:NamespaceID"` - Registered bool // temp RegisterMethod string - AuthKeyID uint - AuthKey *PreAuthKey + + // TODO(kradalby): This seems like irrelevant information? + AuthKeyID uint + AuthKey *PreAuthKey LastSeen *time.Time LastSuccessfulUpdate *time.Time Expiry *time.Time + // TODO(kradalby): Figure out a way to use tailcfg datatypes + // here and have gorm serialise them. HostInfo datatypes.JSON Endpoints datatypes.JSON EnabledRoutes datatypes.JSON @@ -65,11 +71,6 @@ type ( MachinesP []*Machine ) -// For the time being this method is rather naive. -func (machine Machine) isRegistered() bool { - return machine.Registered -} - type MachineAddresses []netaddr.IP func (ma MachineAddresses) ToStringSlice() []string { @@ -116,7 +117,7 @@ func (machine Machine) isExpired() bool { // If Expiry is not set, the client has not indicated that // it wants an expiry time, it is therefor considered // to mean "not expired" - if machine.Expiry.IsZero() { + if machine.Expiry == nil || machine.Expiry.IsZero() { return false } @@ -173,6 +174,12 @@ func getFilteredByACLPeers( machine.IPAddresses.ToStringSlice(), peer.IPAddresses.ToStringSlice(), ) || // match source and destination + matchSourceAndDestinationWithRule( + rule.SrcIPs, + dst, + peer.IPAddresses.ToStringSlice(), + machine.IPAddresses.ToStringSlice(), + ) || // match return path matchSourceAndDestinationWithRule( rule.SrcIPs, dst, @@ -182,9 +189,21 @@ func getFilteredByACLPeers( matchSourceAndDestinationWithRule( rule.SrcIPs, dst, + []string{"*"}, + []string{"*"}, + ) || // match source and all destination + matchSourceAndDestinationWithRule( + rule.SrcIPs, + dst, + []string{"*"}, peer.IPAddresses.ToStringSlice(), + ) || // match source and all destination + matchSourceAndDestinationWithRule( + rule.SrcIPs, + dst, + []string{"*"}, machine.IPAddresses.ToStringSlice(), - ) { // match return path + ) { // match all sources and source peers[peer.ID] = peer } } @@ -214,7 +233,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { Msg("Finding direct peers") machines := Machines{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ? AND registered", + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ?", machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") @@ -277,7 +296,7 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { } for _, peer := range peers { - if peer.isRegistered() && !peer.isExpired() { + if !peer.isExpired() { validPeers = append(validPeers, peer) } } @@ -366,8 +385,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) { // DeleteMachine softs deletes a Machine from the database. func (h *Headscale) DeleteMachine(machine *Machine) error { - machine.Registered = false - h.db.Save(&machine) // we mark it as unregistered, just in case if err := h.db.Delete(&machine).Error; err != nil { return err } @@ -635,7 +652,7 @@ func (machine Machine) toNode( LastSeen: machine.LastSeen, KeepAlive: true, - MachineAuthorized: machine.Registered, + MachineAuthorized: !machine.isExpired(), Capabilities: []string{tailcfg.CapabilityFileSharing}, } @@ -653,8 +670,6 @@ func (machine *Machine) toProto() *v1.Machine { Name: machine.Name, Namespace: machine.Namespace.toProto(), - Registered: machine.Registered, - // TODO(kradalby): Implement register method enum converter // RegisterMethod: , @@ -682,74 +697,50 @@ func (machine *Machine) toProto() *v1.Machine { return machineProto } -// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine( +func (h *Headscale) RegisterMachineFromAuthCallback( machineKeyStr string, namespaceName string, + registrationMethod string, ) (*Machine, error) { - namespace, err := h.GetNamespace(namespaceName) - if err != nil { - return nil, err - } + if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { + if registrationMachine, ok := machineInterface.(Machine); ok { + namespace, err := h.GetNamespace(namespaceName) + if err != nil { + return nil, fmt.Errorf( + "failed to find namespace in register machine from auth callback, %w", + err, + ) + } - var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) - if err != nil { - return nil, err - } + registrationMachine.NamespaceID = namespace.ID + registrationMachine.RegisterMethod = registrationMethod - log.Trace(). - Caller(). - Str("machine_key_str", machineKeyStr). - Str("machine_key", machineKey.String()). - Msg("Registering machine") + machine, err := h.RegisterMachine( + registrationMachine, + ) - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - return nil, err - } - - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - // This means that if a user is to slow with register a machine, it will possibly not - // have the correct expiry. - requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { - log.Trace(). - Caller(). - Str("machine", machine.Name). - Msg("Expiry time found in cache, assigning to node") - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime + return machine, err + } else { + return nil, errCouldNotConvertMachineInterface } } - if machine.isRegistered() { - log.Trace(). - Caller(). - Str("machine", machine.Name). - Msg("machine already registered, reauthenticating") + return nil, errMachineNotFoundRegistrationCache +} - h.RefreshMachine(machine, requestedTime) - - return machine, nil - } +// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. +func (h *Headscale) RegisterMachine(machine Machine, +) (*Machine, error) { + log.Trace(). + Caller(). + Str("machine_key", machine.MachineKey). + Msg("Registering machine") log.Trace(). Caller(). Str("machine", machine.Name). Msg("Attempting to register machine") - if machine.isRegistered() { - err := errMachineAlreadyRegistered - log.Error(). - Caller(). - Err(err). - Str("machine", machine.Name). - Msg("Attempting to register machine") - - return nil, err - } - h.ipAllocationMutex.Lock() defer h.ipAllocationMutex.Unlock() @@ -764,17 +755,8 @@ func (h *Headscale) RegisterMachine( return nil, err } - log.Trace(). - Caller(). - Str("machine", machine.Name). - Str("ip", strings.Join(ips.ToStringSlice(), ",")). - Msg("Found IP for host") - machine.IPAddresses = ips - machine.NamespaceID = namespace.ID - machine.Registered = true - machine.RegisterMethod = RegisterMethodCLI - machine.Expiry = &requestedTime + h.db.Save(&machine) log.Trace(). @@ -783,7 +765,7 @@ func (h *Headscale) RegisterMachine( Str("ip", strings.Join(ips.ToStringSlice(), ",")). Msg("Machine registered with the database") - return machine, nil + return &machine, nil } func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { diff --git a/machine_test.go b/machine_test.go index e9c91f8..b5a462f 100644 --- a/machine_test.go +++ b/machine_test.go @@ -29,7 +29,6 @@ func (s *Suite) TestGetMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -59,7 +58,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -82,7 +80,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } @@ -105,7 +102,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine3", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } @@ -136,7 +132,6 @@ func (s *Suite) TestListPeers(c *check.C) { DiscoKey: "faa" + strconv.Itoa(index), Name: "testmachine" + strconv.Itoa(index), NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -188,7 +183,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { }, Name: "testmachine" + strconv.Itoa(index), NamespaceID: stor[index%2].namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } @@ -258,7 +252,6 @@ func (s *Suite) TestExpireMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, @@ -296,6 +289,7 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { } } +// nolint func Test_getFilteredByACLPeers(t *testing.T) { type args struct { machines []Machine @@ -443,7 +437,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, machine: &Machine{ // current machine - ID: 1, + ID: 2, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "marc"}, }, @@ -456,6 +450,208 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, }, + { + name: "rules allows all hosts to reach one destination", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.3"), + }, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + }, + want: Machines{ + { + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination, destination can reach all hosts", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.3"), + }, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + }, + want: Machines{ + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.3"), + }, + Namespace: Namespace{Name: "mickael"}, + }, + }, + }, + { + name: "rule allows all hosts to reach all destinations", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.3"), + }, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + }, + want: Machines{ + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + }, + { + name: "without rule all communications are forbidden", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.1"), + }, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.2"), + }, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{ + netaddr.MustParseIP("100.64.0.3"), + }, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + }, + machine: &Machine{ // current machine + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + }, + want: Machines{}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/namespaces_test.go b/namespaces_test.go index 585ba93..1710f7d 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -54,7 +54,6 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -146,7 +145,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Name: "test_get_shared_nodes_1", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, AuthKeyID: uint(preAuthKeyShared1.ID), @@ -164,7 +162,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Name: "test_get_shared_nodes_2", NamespaceID: namespaceShared2.ID, Namespace: *namespaceShared2, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, AuthKeyID: uint(preAuthKeyShared2.ID), @@ -182,7 +179,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Name: "test_get_shared_nodes_3", NamespaceID: namespaceShared3.ID, Namespace: *namespaceShared3, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, AuthKeyID: uint(preAuthKeyShared3.ID), @@ -200,7 +196,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { Name: "test_get_shared_nodes_4", NamespaceID: namespaceShared1.ID, Namespace: *namespaceShared1, - Registered: true, RegisterMethod: RegisterMethodAuthKey, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, AuthKeyID: uint(preAuthKey2Shared1.ID), diff --git a/oidc.go b/oidc.go index edea32a..987bc8b 100644 --- a/oidc.go +++ b/oidc.go @@ -10,21 +10,16 @@ import ( "html/template" "net/http" "strings" - "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" - "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "gorm.io/gorm" "tailscale.com/types/key" ) const ( - oidcStateCacheExpiration = time.Minute * 5 - oidcStateCacheCleanupInterval = time.Minute * 10 - randomByteSize = 16 + randomByteSize = 16 ) type IDTokenClaims struct { @@ -61,14 +56,6 @@ func (h *Headscale) initOIDC() error { } } - // init the state cache if it hasn't been already - if h.oidcStateCache == nil { - h.oidcStateCache = cache.New( - oidcStateCacheExpiration, - oidcStateCacheCleanupInterval, - ) - } - return nil } @@ -101,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) + h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) authURL := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -125,7 +112,6 @@ var oidcCallbackTemplate = template.Must( `), ) -// TODO: Why is the entire machine registration logic duplicated here? // OIDCCallback handles the callback from the OIDC endpoint // Retrieves the mkey from the state cache and adds the machine to the users email namespace // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities @@ -197,7 +183,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { } // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) + machineKeyIf, machineKeyFound := h.registrationCache.Get(state) if !machineKeyFound { log.Error(). @@ -207,10 +193,12 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - machineKeyStr, machineKeyOK := machineKeyIf.(string) + machineKeyFromCache, machineKeyOK := machineKeyIf.(string) var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + err = machineKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), + ) if err != nil { log.Error(). Msg("could not parse machine public key") @@ -229,33 +217,19 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime - } - } + // retrieve machine information if it exist + // The error is not important, because if it does not + // exist, then this is a new machine and we will move + // on to registration. + machine, _ := h.GetMachineByMachineKey(machineKey) - // retrieve machine information - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - log.Error().Msg("machine key not found in database") - ctx.String( - http.StatusInternalServerError, - "could not get machine info from database", - ) - - return - } - - if machine.isRegistered() { + if machine != nil { log.Trace(). Caller(). Str("machine", machine.Name). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, requestedTime) + h.RefreshMachine(machine, *machine.Expiry) var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ @@ -279,8 +253,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - now := time.Now().UTC() - namespaceName, err := NormalizeNamespaceName( claims.Email, h.cfg.OIDC.StripEmaildomain, @@ -294,61 +266,58 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + // register the machine if it's new - if !machine.Registered { - log.Debug().Msg("Registering new machine after successful callback") + log.Debug().Msg("Registering new machine after successful callback") - namespace, err := h.GetNamespace(namespaceName) - if errors.Is(err, gorm.ErrRecordNotFound) { - namespace, err = h.CreateNamespace(namespaceName) + namespace, err := h.GetNamespace(namespaceName) + if errors.Is(err, errNamespaceNotFound) { + namespace, err = h.CreateNamespace(namespaceName) - if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) - - return - } - } else if err != nil { - log.Error(). - Caller(). - Err(err). - Str("namespace", namespaceName). - Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) - - return - } - - ips, err := h.getAvailableIPs() if err != nil { log.Error(). - Caller(). Err(err). - Msg("could not get an IP from the pool") + Caller(). + Msgf("could not create new namespace '%s'", namespaceName) ctx.String( http.StatusInternalServerError, - "could not get an IP from the pool", + "could not create new namespace", ) return } + } else if err != nil { + log.Error(). + Caller(). + Err(err). + Str("namespace", namespaceName). + Msg("could not find or create namespace") + ctx.String( + http.StatusInternalServerError, + "could not find or create namespace", + ) - machine.IPAddresses = ips - machine.NamespaceID = namespace.ID - machine.Registered = true - machine.RegisterMethod = RegisterMethodOIDC - machine.LastSuccessfulUpdate = &now - machine.Expiry = &requestedTime - h.db.Save(&machine) + return + } + + machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + + _, err = h.RegisterMachineFromAuthCallback( + machineKeyStr, + namespace.Name, + RegisterMethodOIDC, + ) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not register machine") + ctx.String( + http.StatusInternalServerError, + "could not register machine", + ) + + return } var content bytes.Buffer diff --git a/preauth_keys.go b/preauth_keys.go index 50bc474..55f6222 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -113,6 +113,12 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { return nil } +// UsePreAuthKey marks a PreAuthKey as used. +func (h *Headscale) UsePreAuthKey(k *PreAuthKey) { + k.Used = true + h.db.Save(k) +} + // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { diff --git a/preauth_keys_test.go b/preauth_keys_test.go index f8cf276..6f233a9 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -80,7 +80,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { DiscoKey: "faa", Name: "testest", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -105,7 +104,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { DiscoKey: "faa", Name: "testest", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } @@ -143,7 +141,6 @@ func (*Suite) TestEphemeralKey(c *check.C) { DiscoKey: "faa", Name: "testest", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, LastSeen: &now, AuthKeyID: uint(pak.ID), diff --git a/proto/headscale/v1/machine.proto b/proto/headscale/v1/machine.proto index 7451db8..75552d7 100644 --- a/proto/headscale/v1/machine.proto +++ b/proto/headscale/v1/machine.proto @@ -22,16 +22,16 @@ message Machine { string name = 6; Namespace namespace = 7; - bool registered = 8; - RegisterMethod register_method = 9; - google.protobuf.Timestamp last_seen = 10; - google.protobuf.Timestamp last_successful_update = 11; - google.protobuf.Timestamp expiry = 12; + google.protobuf.Timestamp last_seen = 8; + google.protobuf.Timestamp last_successful_update = 9; + google.protobuf.Timestamp expiry = 10; - PreAuthKey pre_auth_key = 13; + PreAuthKey pre_auth_key = 11; - google.protobuf.Timestamp created_at = 14; + google.protobuf.Timestamp created_at = 12; + + RegisterMethod register_method = 13; // google.protobuf.Timestamp updated_at = 14; // google.protobuf.Timestamp deleted_at = 15; diff --git a/routes_test.go b/routes_test.go index 94bda45..1c24bf2 100644 --- a/routes_test.go +++ b/routes_test.go @@ -35,7 +35,6 @@ func (s *Suite) TestGetRoutes(c *check.C) { DiscoKey: "faa", Name: "test_get_route_machine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostinfo), @@ -89,7 +88,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { DiscoKey: "faa", Name: "test_enable_route_machine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostinfo), diff --git a/utils_test.go b/utils_test.go index 271013a..7c72c09 100644 --- a/utils_test.go +++ b/utils_test.go @@ -36,7 +36,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, @@ -85,7 +84,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, @@ -176,7 +174,6 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: namespace.ID, - Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), }