Merge pull request #689 from restanrm/fix-duplicate-tags-returned-by-api
Remove duplicate tags if sent by the client
This commit is contained in:
commit
32a8f06486
5 changed files with 110 additions and 7 deletions
22
grpcv1.go
22
grpcv1.go
|
@ -3,6 +3,7 @@ package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -195,13 +196,11 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tag := range request.GetTags() {
|
for _, tag := range request.GetTags() {
|
||||||
if strings.Index(tag, "tag:") != 0 {
|
err := validateTag(tag)
|
||||||
|
if err != nil {
|
||||||
return &v1.SetTagsResponse{
|
return &v1.SetTagsResponse{
|
||||||
Machine: nil,
|
Machine: nil,
|
||||||
}, status.Error(
|
}, status.Error(codes.InvalidArgument, err.Error())
|
||||||
codes.InvalidArgument,
|
|
||||||
"Invalid tag detected. Each tag must start with the string 'tag:'",
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,6 +219,19 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
|
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateTag(tag string) error {
|
||||||
|
if strings.Index(tag, "tag:") != 0 {
|
||||||
|
return fmt.Errorf("tag must start with the string 'tag:'")
|
||||||
|
}
|
||||||
|
if strings.ToLower(tag) != tag {
|
||||||
|
return fmt.Errorf("tag should be lowercase")
|
||||||
|
}
|
||||||
|
if len(strings.Fields(tag)) > 1 {
|
||||||
|
return fmt.Errorf("tag should not contains space")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) DeleteMachine(
|
func (api headscaleV1APIServer) DeleteMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteMachineRequest,
|
request *v1.DeleteMachineRequest,
|
||||||
|
|
42
grpcv1_test.go
Normal file
42
grpcv1_test.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package headscale
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func Test_validateTag(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
tag string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid tag",
|
||||||
|
args: args{tag: "tag:test"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tag without tag prefix",
|
||||||
|
args: args{tag: "test"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase tag",
|
||||||
|
args: args{tag: "tag:tEST"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tag that contains space",
|
||||||
|
args: args{tag: "tag:this is a spaced tag"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -625,7 +625,7 @@ func (s *IntegrationCLITestSuite) TestNodeTagCommand() {
|
||||||
var errorOutput errOutput
|
var errorOutput errOutput
|
||||||
err = json.Unmarshal([]byte(wrongTagResult), &errorOutput)
|
err = json.Unmarshal([]byte(wrongTagResult), &errorOutput)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
assert.Contains(s.T(), errorOutput.Error, "Invalid tag detected")
|
assert.Contains(s.T(), errorOutput.Error, "tag must start with the string 'tag:'")
|
||||||
|
|
||||||
// Test list all nodes after added seconds
|
// Test list all nodes after added seconds
|
||||||
listAllResult, err := ExecuteCommand(
|
listAllResult, err := ExecuteCommand(
|
||||||
|
|
|
@ -374,7 +374,13 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
||||||
|
|
||||||
// SetTags takes a Machine struct pointer and update the forced tags.
|
// SetTags takes a Machine struct pointer and update the forced tags.
|
||||||
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
||||||
machine.ForcedTags = tags
|
newTags := []string{}
|
||||||
|
for _, tag := range tags {
|
||||||
|
if !contains(newTags, tag) {
|
||||||
|
newTags = append(newTags, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
machine.ForcedTags = newTags
|
||||||
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -280,6 +280,49 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestSetTags(c *check.C) {
|
||||||
|
namespace, err := app.CreateNamespace("test")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = app.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
machine := &Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Hostname: "testmachine",
|
||||||
|
NamespaceID: namespace.ID,
|
||||||
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
app.db.Save(machine)
|
||||||
|
|
||||||
|
// assign simple tags
|
||||||
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
|
err = app.SetTags(machine, sTags)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine, err = app.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags))
|
||||||
|
|
||||||
|
// assign duplicat tags, expect no errors but no doubles in DB
|
||||||
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
|
err = app.SetTags(machine, eTags)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
machine, err = app.GetMachine("test", "testmachine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(
|
||||||
|
machine.ForcedTags,
|
||||||
|
check.DeepEquals,
|
||||||
|
StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_getTags(t *testing.T) {
|
func Test_getTags(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
|
|
Loading…
Reference in a new issue