Set tags as part of handleAuthKeyCommon

This commit is contained in:
Benjamin George Roberts 2022-08-25 20:43:15 +10:00
parent 6faa1d2e4a
commit ac18723dd4
5 changed files with 75 additions and 6 deletions

View file

@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
expiration = request.GetExpiration().AsTime() expiration = request.GetExpiration().AsTime()
} }
if len(request.AclTags) > 0 {
for _, tag := range request.AclTags {
err := validateTag(tag)
if err != nil {
return &v1.CreatePreAuthKeyResponse{
PreAuthKey: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
}
preAuthKey, err := api.h.CreatePreAuthKey( preAuthKey, err := api.h.CreatePreAuthKey(
request.GetNamespace(), request.GetNamespace(),
request.GetReusable(), request.GetReusable(),

View file

@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
"24h", "24h",
"--output", "--output",
"json", "json",
"--tags",
"tag:test1,tag:test2",
}, },
[]string{}, []string{},
) )
@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
) )
// Test that tags are present
for i := 0; i < count; i++ {
assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"})
}
// Expire three keys // Expire three keys
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := ExecuteCommand( _, err := ExecuteCommand(

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err return nil, err
} }
for _, tag := range aclTags {
if !strings.HasPrefix(tag, "tag:") {
return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag)
}
}
now := time.Now().UTC() now := time.Now().UTC()
kstr, err := h.generateKey() kstr, err := h.generateKey()
if err != nil { if err != nil {
@ -77,13 +84,18 @@ func (h *Headscale) CreatePreAuthKey(
} }
if len(aclTags) > 0 { if len(aclTags) > 0 {
seenTags := map[string]bool{}
for _, tag := range aclTags { for _, tag := range aclTags {
if seenTags[tag] == false {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf( return fmt.Errorf(
"failed to create key tag in the database: %w", "failed to ceate key tag in the database: %w",
err, err,
) )
} }
seenTags[tag] = true
}
} }
} }
return nil return nil
@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
if len(key.AclTags) > 0 { if len(key.AclTags) > 0 {
for idx := range key.AclTags { for idx := range key.AclTags {
protoKey.AclTags[idx] = key.AclTags[0].Tag protoKey.AclTags[idx] = key.AclTags[idx].Tag
} }
} }

View file

@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
_, err = app.checkKeyValidity(pak.Key) _, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
} }
func (*Suite) TestPreAuthKeyAclTags(c *check.C) {
namespace, err := app.CreateNamespace("test8")
c.Assert(err, check.IsNil)
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil)
listedPaks, err := app.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil)
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
}

View file

@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
machine.NodeKey = nodeKey machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry) err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon(
return return
} }
aclTags := pak.toProto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags)
}
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Strs("aclTags", aclTags).
Err(err).
Msg("Failed to set tags after refreshing machine")
return
}
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon(
NodeKey: nodeKey, NodeKey: nodeKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags,
} }
machine, err = h.RegisterMachine( machine, err = h.RegisterMachine(