Set tags as part of handleAuthKeyCommon
This commit is contained in:
parent
6faa1d2e4a
commit
ac18723dd4
5 changed files with 75 additions and 6 deletions
12
grpcv1.go
12
grpcv1.go
|
@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
|||
expiration = request.GetExpiration().AsTime()
|
||||
}
|
||||
|
||||
if len(request.AclTags) > 0 {
|
||||
for _, tag := range request.AclTags {
|
||||
err := validateTag(tag)
|
||||
|
||||
if err != nil {
|
||||
return &v1.CreatePreAuthKeyResponse{
|
||||
PreAuthKey: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preAuthKey, err := api.h.CreatePreAuthKey(
|
||||
request.GetNamespace(),
|
||||
request.GetReusable(),
|
||||
|
|
|
@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
|||
"24h",
|
||||
"--output",
|
||||
"json",
|
||||
"--tags",
|
||||
"tag:test1,tag:test2",
|
||||
},
|
||||
[]string{},
|
||||
)
|
||||
|
@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
|||
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
|
||||
// Test that tags are present
|
||||
for i := 0; i < count; i++ {
|
||||
assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"})
|
||||
}
|
||||
|
||||
// Expire three keys
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := ExecuteCommand(
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
|
@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if !strings.HasPrefix(tag, "tag:") {
|
||||
return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
kstr, err := h.generateKey()
|
||||
if err != nil {
|
||||
|
@ -77,12 +84,17 @@ func (h *Headscale) CreatePreAuthKey(
|
|||
}
|
||||
|
||||
if len(aclTags) > 0 {
|
||||
seenTags := map[string]bool{}
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to create key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
if seenTags[tag] == false {
|
||||
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
seenTags[tag] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
|||
|
||||
if len(key.AclTags) > 0 {
|
||||
for idx := range key.AclTags {
|
||||
protoKey.AclTags[idx] = key.AclTags[0].Tag
|
||||
protoKey.AclTags[idx] = key.AclTags[idx].Tag
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
|||
_, err = app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyAclTags(c *check.C) {
|
||||
namespace, err := app.CreateNamespace("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"})
|
||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||
|
||||
tags := []string{"tag:test1", "tag:test2"}
|
||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
listedPaks, err := app.ListPreAuthKeys("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
||||
}
|
||||
|
|
|
@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
machine.NodeKey = nodeKey
|
||||
machine.AuthKeyID = uint(pak.ID)
|
||||
err := h.RefreshMachine(machine, registerRequest.Expiry)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
aclTags := pak.toProto().AclTags
|
||||
if len(aclTags) > 0 {
|
||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||
err = h.SetTags(machine, aclTags)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("noise", machineKey.IsZero()).
|
||||
Str("machine", machine.Hostname).
|
||||
Strs("aclTags", aclTags).
|
||||
Err(err).
|
||||
Msg("Failed to set tags after refreshing machine")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
now := time.Now().UTC()
|
||||
|
||||
|
@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
|||
NodeKey: nodeKey,
|
||||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
ForcedTags: pak.toProto().AclTags,
|
||||
}
|
||||
|
||||
machine, err = h.RegisterMachine(
|
||||
|
|
Loading…
Reference in a new issue