Set tags as part of handleAuthKeyCommon
This commit is contained in:
parent
6faa1d2e4a
commit
ac18723dd4
5 changed files with 75 additions and 6 deletions
12
grpcv1.go
12
grpcv1.go
|
@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
expiration = request.GetExpiration().AsTime()
|
expiration = request.GetExpiration().AsTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(request.AclTags) > 0 {
|
||||||
|
for _, tag := range request.AclTags {
|
||||||
|
err := validateTag(tag)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &v1.CreatePreAuthKeyResponse{
|
||||||
|
PreAuthKey: nil,
|
||||||
|
}, status.Error(codes.InvalidArgument, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
preAuthKey, err := api.h.CreatePreAuthKey(
|
preAuthKey, err := api.h.CreatePreAuthKey(
|
||||||
request.GetNamespace(),
|
request.GetNamespace(),
|
||||||
request.GetReusable(),
|
request.GetReusable(),
|
||||||
|
|
|
@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
||||||
"24h",
|
"24h",
|
||||||
"--output",
|
"--output",
|
||||||
"json",
|
"json",
|
||||||
|
"--tags",
|
||||||
|
"tag:test1,tag:test2",
|
||||||
},
|
},
|
||||||
[]string{},
|
[]string{},
|
||||||
)
|
)
|
||||||
|
@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
||||||
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Test that tags are present
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"})
|
||||||
|
}
|
||||||
|
|
||||||
// Expire three keys
|
// Expire three keys
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
_, err := ExecuteCommand(
|
_, err := ExecuteCommand(
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, tag := range aclTags {
|
||||||
|
if !strings.HasPrefix(tag, "tag:") {
|
||||||
|
return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
kstr, err := h.generateKey()
|
kstr, err := h.generateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -77,13 +84,18 @@ func (h *Headscale) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(aclTags) > 0 {
|
if len(aclTags) > 0 {
|
||||||
|
seenTags := map[string]bool{}
|
||||||
|
|
||||||
for _, tag := range aclTags {
|
for _, tag := range aclTags {
|
||||||
|
if seenTags[tag] == false {
|
||||||
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to create key tag in the database: %w",
|
"failed to ceate key tag in the database: %w",
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
seenTags[tag] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
||||||
|
|
||||||
if len(key.AclTags) > 0 {
|
if len(key.AclTags) > 0 {
|
||||||
for idx := range key.AclTags {
|
for idx := range key.AclTags {
|
||||||
protoKey.AclTags[idx] = key.AclTags[0].Tag
|
protoKey.AclTags[idx] = key.AclTags[idx].Tag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
_, err = app.checkKeyValidity(pak.Key)
|
_, err = app.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestPreAuthKeyAclTags(c *check.C) {
|
||||||
|
namespace, err := app.CreateNamespace("test8")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"})
|
||||||
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
|
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
listedPaks, err := app.ListPreAuthKeys("test8")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
||||||
|
}
|
||||||
|
|
|
@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
machine.NodeKey = nodeKey
|
machine.NodeKey = nodeKey
|
||||||
machine.AuthKeyID = uint(pak.ID)
|
machine.AuthKeyID = uint(pak.ID)
|
||||||
err := h.RefreshMachine(machine, registerRequest.Expiry)
|
err := h.RefreshMachine(machine, registerRequest.Expiry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
aclTags := pak.toProto().AclTags
|
||||||
|
if len(aclTags) > 0 {
|
||||||
|
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||||
|
err = h.SetTags(machine, aclTags)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Bool("noise", machineKey.IsZero()).
|
||||||
|
Str("machine", machine.Hostname).
|
||||||
|
Strs("aclTags", aclTags).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to set tags after refreshing machine")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
NodeKey: nodeKey,
|
NodeKey: nodeKey,
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
|
ForcedTags: pak.toProto().AclTags,
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err = h.RegisterMachine(
|
machine, err = h.RegisterMachine(
|
||||||
|
|
Loading…
Reference in a new issue