From 402a76070f501bf490dd75391b72ef1743db66e8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 28 Feb 2022 16:34:28 +0000 Subject: [PATCH] Reuse machine structure for parameters, named parameters --- machine.go | 89 +++++++++++++++++------------------------------------- 1 file changed, 27 insertions(+), 62 deletions(-) diff --git a/machine.go b/machine.go index 7425155..6e5b62d 100644 --- a/machine.go +++ b/machine.go @@ -21,7 +21,6 @@ import ( const ( errMachineNotFound = Error("machine not found") - errMachineAlreadyRegistered = Error("machine already registered") errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineAddressesInvalid = Error("failed to parse machine addresses") errMachineNotFoundRegistrationCache = Error( @@ -698,19 +697,23 @@ func (h *Headscale) RegisterMachineFromAuthCallback( ) (*Machine, error) { if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { + namespace, err := h.GetNamespace(namespaceName) + if err != nil { + return nil, fmt.Errorf( + "failed to find namespace in register machine from auth callback, %w", + err, + ) + } + + registrationMachine.NamespaceID = namespace.ID + registrationMachine.RegisterMethod = registrationMethod + registrationMachine.Expiry = expiry + machine, err := h.RegisterMachine( - registrationMachine.Name, - machineKeyStr, - namespaceName, - registrationMethod, - expiry, - nil, - ®istrationMachine.NodeKey, - registrationMachine.LastSeen, + registrationMachine, ) return machine, err - } else { return nil, errCouldNotConvertMachineInterface } @@ -720,49 +723,30 @@ func (h *Headscale) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine( - machineName string, - machineKeyStr string, - namespaceName string, - registrationMethod string, - expiry *time.Time, - - // Optionals - authKey *PreAuthKey, - nodePublicKey *string, - lastSeen *time.Time, +func (h *Headscale) RegisterMachine(machine Machine, ) (*Machine, error) { - namespace, err := h.GetNamespace(namespaceName) - if err != nil { - return nil, err - } - - var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) - if err != nil { - return nil, err - } - log.Trace(). Caller(). - Str("machine_key_str", machineKeyStr). - Str("machine_key", machineKey.String()). + Str("machine_key", machine.MachineKey). Msg("Registering machine") - machine, err := h.GetMachineByMachineKey(machineKey) - if err != nil { - return nil, err - } - + // If the machine is already in the database, it is seeking + // reauthentication, and by reaching this step, has been authenticated + // and need to have an updated expiry. + var machineKey key.MachinePublic + _ = machineKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) + machineFromDatabase, _ := h.GetMachineByMachineKey(machineKey) if machine.isRegistered() { log.Trace(). Caller(). Str("machine", machine.Name). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, *expiry) + h.RefreshMachine(machineFromDatabase, *machine.Expiry) - return machine, nil + return machineFromDatabase, nil } log.Trace(). @@ -786,28 +770,9 @@ func (h *Headscale) RegisterMachine( machine.IPAddresses = ips - if expiry != nil { - machine.Expiry = expiry - } - - if authKey != nil { - machine.AuthKeyID = uint(authKey.ID) - } - - if nodePublicKey != nil { - machine.NodeKey = *nodePublicKey - } - - if lastSeen != nil { - machine.LastSeen = lastSeen - } - - machine.Name = machineName - machine.NamespaceID = namespace.ID - // TODO(kradalby): This field is uneccessary metadata, // move it to tags instead of having a column. - machine.RegisterMethod = registrationMethod + // machine.RegisterMethod = registrationMethod // TODO(kradalby): Registered is a very frustrating value // to keep up to date, and it makes is have to care if a @@ -824,7 +789,7 @@ func (h *Headscale) RegisterMachine( Str("ip", strings.Join(ips.ToStringSlice(), ",")). Msg("Machine registered with the database") - return machine, nil + return &machine, nil } func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {