From 16b21e8158d7a32d1ed1b093f05268a633a5bd17 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 28 Feb 2022 16:55:57 +0000 Subject: [PATCH] Remove all references to Machine.Registered --- machine.go | 24 +++------------- oidc.go | 81 +++++++++++++++++++++++++++--------------------------- 2 files changed, 44 insertions(+), 61 deletions(-) diff --git a/machine.go b/machine.go index 6e5b62d..16d6d86 100644 --- a/machine.go +++ b/machine.go @@ -72,11 +72,6 @@ type ( MachinesP []*Machine ) -// For the time being this method is rather naive. -func (machine Machine) isRegistered() bool { - return machine.Registered -} - type MachineAddresses []netaddr.IP func (ma MachineAddresses) ToStringSlice() []string { @@ -221,7 +216,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { Msg("Finding direct peers") machines := Machines{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ? AND registered", + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ?", machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") @@ -284,7 +279,7 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { } for _, peer := range peers { - if peer.isRegistered() && !peer.isExpired() { + if !peer.isExpired() { validPeers = append(validPeers, peer) } } @@ -373,8 +368,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) { // DeleteMachine softs deletes a Machine from the database. func (h *Headscale) DeleteMachine(machine *Machine) error { - machine.Registered = false - h.db.Save(&machine) // we mark it as unregistered, just in case if err := h.db.Delete(&machine).Error; err != nil { return err } @@ -642,7 +635,7 @@ func (machine Machine) toNode( LastSeen: machine.LastSeen, KeepAlive: true, - MachineAuthorized: machine.Registered, + MachineAuthorized: !machine.isExpired(), Capabilities: []string{tailcfg.CapabilityFileSharing}, } @@ -660,8 +653,6 @@ func (machine *Machine) toProto() *v1.Machine { Name: machine.Name, Namespace: machine.Namespace.toProto(), - Registered: machine.Registered, - // TODO(kradalby): Implement register method enum converter // RegisterMethod: , @@ -738,7 +729,7 @@ func (h *Headscale) RegisterMachine(machine Machine, []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) machineFromDatabase, _ := h.GetMachineByMachineKey(machineKey) - if machine.isRegistered() { + if machineFromDatabase != nil { log.Trace(). Caller(). Str("machine", machine.Name). @@ -774,13 +765,6 @@ func (h *Headscale) RegisterMachine(machine Machine, // move it to tags instead of having a column. // machine.RegisterMethod = registrationMethod - // TODO(kradalby): Registered is a very frustrating value - // to keep up to date, and it makes is have to care if a - // machine is registered, authenticated and expired. - // Let us simplify the model, a machine is _only_ saved if - // it is registered. - machine.Registered = true - h.db.Save(&machine) log.Trace(). diff --git a/oidc.go b/oidc.go index f291e8c..e745236 100644 --- a/oidc.go +++ b/oidc.go @@ -248,7 +248,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - if machine.isRegistered() { + if machine != nil { log.Trace(). Caller(). Str("machine", machine.Name). @@ -291,58 +291,57 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + // register the machine if it's new - if !machine.Registered { - log.Debug().Msg("Registering new machine after successful callback") + log.Debug().Msg("Registering new machine after successful callback") - namespace, err := h.GetNamespace(namespaceName) - if errors.Is(err, errNamespaceNotFound) { - namespace, err = h.CreateNamespace(namespaceName) + namespace, err := h.GetNamespace(namespaceName) + if errors.Is(err, errNamespaceNotFound) { + namespace, err = h.CreateNamespace(namespaceName) - if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) - - return - } - } else if err != nil { - log.Error(). - Caller(). - Err(err). - Str("namespace", namespaceName). - Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) - - return - } - - _, err = h.RegisterMachineFromAuthCallback( - machineKeyStr, - namespace.Name, - RegisterMethodOIDC, - &requestedTime, - ) if err != nil { log.Error(). - Caller(). Err(err). - Msg("could not register machine") + Caller(). + Msgf("could not create new namespace '%s'", namespaceName) ctx.String( http.StatusInternalServerError, - "could not register machine", + "could not create new namespace", ) return } + } else if err != nil { + log.Error(). + Caller(). + Err(err). + Str("namespace", namespaceName). + Msg("could not find or create namespace") + ctx.String( + http.StatusInternalServerError, + "could not find or create namespace", + ) + + return + } + + _, err = h.RegisterMachineFromAuthCallback( + machineKeyStr, + namespace.Name, + RegisterMethodOIDC, + &requestedTime, + ) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not register machine") + ctx.String( + http.StatusInternalServerError, + "could not register machine", + ) + + return } var content bytes.Buffer