diff --git a/grpcv1.go b/grpcv1.go index 284c175..676c0e7 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -3,8 +3,6 @@ package headscale import ( "context" - "fmt" - "strconv" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -184,35 +182,23 @@ func (api headscaleV1APIServer) GetMachine( return &v1.GetMachineResponse{Machine: machine.toProto()}, nil } -func (api headscaleV1APIServer) UpdateMachine( +func (api headscaleV1APIServer) SetTags( ctx context.Context, - request *v1.UpdateMachineRequest, -) (*v1.UpdateMachineResponse, error) { - rMachine := request.GetMachine() - machine, err := api.h.GetMachineByID(rMachine.Id) + request *v1.SetTagsRequest, +) (*v1.SetTagsResponse, error) { + machine, err := api.h.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - machine.ForcedTags = rMachine.ForcedTags - machine.Name = rMachine.Name - id, err := strconv.Atoi(rMachine.Namespace.Id) - if err != nil { - return nil, fmt.Errorf("failed to convert namespace id to integer: %w", err) - } - machine.NamespaceID = uint(id) + api.h.SetTags(machine, request.GetTags()) - err = api.h.UpdateDBMachine(*machine) - if err != nil { - return nil, err - } + log.Trace(). + Str("machine", machine.Name). + Strs("tags", request.GetTags()). + Msg("Changing tags of machine") - machine, err = api.h.GetMachineByID(rMachine.Id) - if err != nil { - return nil, err - } - - return &v1.UpdateMachineResponse{Machine: machine.toProto()}, nil + return &v1.SetTagsResponse{Machine: machine.toProto()}, nil } func (api headscaleV1APIServer) DeleteMachine( diff --git a/machine.go b/machine.go index 9fc06a4..2b5da37 100644 --- a/machine.go +++ b/machine.go @@ -360,18 +360,15 @@ func (h *Headscale) UpdateMachine(machine *Machine) error { return nil } -// UpdateDBMachine takes a Machine struct pointer (typically already loaded from database -// search for the same machine in the database and update the latter. -func (h *Headscale) UpdateDBMachine(machine Machine) error { - destMachine := Machine{} - if result := h.db.Where("id = ?", machine.ID).Find(&destMachine); result.Error != nil { - return result.Error +// SetTags takes a Machine struct pointer and update the forced tags. +func (h *Headscale) SetTags(machine *Machine, tags []string) error { + machine.ForcedTags = tags + err := h.UpdateACLRules() + if err != nil { + return err } - destMachine.Name = machine.Name - destMachine.NamespaceID = machine.NamespaceID - destMachine.ForcedTags = machine.ForcedTags - - h.db.Save(destMachine) + h.setLastStateChangeToNow(machine.Namespace.Name) + h.db.Save(machine) return nil }