rename acl "get" funcs to "expand" for consistency
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
155cc072f7
commit
19dc0ac702
4 changed files with 59 additions and 62 deletions
|
@ -338,7 +338,7 @@ func (api headscaleV1APIServer) ListMachines(
|
|||
response := make([]*v1.Machine, len(machines))
|
||||
for index, machine := range machines {
|
||||
m := machine.Proto()
|
||||
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
|
||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfMachine(
|
||||
machine,
|
||||
)
|
||||
m.InvalidTags = invalidTags
|
||||
|
|
|
@ -104,7 +104,7 @@ func tailNode(
|
|||
|
||||
online := machine.IsOnline()
|
||||
|
||||
tags, _ := pol.GetTagsOfMachine(machine)
|
||||
tags, _ := pol.TagsOfMachine(machine)
|
||||
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
||||
|
||||
node := tailcfg.Node{
|
||||
|
|
|
@ -114,9 +114,6 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
|
|||
return &policy, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): This needs to be replace with something that generates
|
||||
// the rules as needed and not stores it on the global object, rules are
|
||||
// per node and that should be taken into account.
|
||||
func GenerateFilterAndSSHRules(
|
||||
policy *ACLPolicy,
|
||||
machine *types.Machine,
|
||||
|
@ -169,7 +166,7 @@ func (pol *ACLPolicy) generateFilterRules(
|
|||
|
||||
srcIPs := []string{}
|
||||
for srcIndex, src := range acl.Sources {
|
||||
srcs, err := pol.getIPsFromSource(src, machines)
|
||||
srcs, err := pol.expandSource(src, machines)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Interface("src", src).
|
||||
|
@ -338,7 +335,7 @@ func (pol *ACLPolicy) generateSSHRules(
|
|||
Any: true,
|
||||
})
|
||||
} else if isGroup(rawSrc) {
|
||||
users, err := pol.getUsersInGroup(rawSrc)
|
||||
users, err := pol.expandUsersFromGroup(rawSrc)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
||||
|
@ -401,26 +398,6 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// getIPsFromSource returns a set of Source IPs that would be associated
|
||||
// with the given src alias.
|
||||
func (pol *ACLPolicy) getIPsFromSource(
|
||||
src string,
|
||||
machines types.Machines,
|
||||
) ([]string, error) {
|
||||
ipSet, err := pol.ExpandAlias(machines, src)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
prefixes := []string{}
|
||||
|
||||
for _, prefix := range ipSet.Prefixes() {
|
||||
prefixes = append(prefixes, prefix.String())
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
func parseDestination(dest string) (string, string, error) {
|
||||
var tokens []string
|
||||
|
||||
|
@ -520,6 +497,26 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// expandSource returns a set of Source IPs that would be associated
|
||||
// with the given src alias.
|
||||
func (pol *ACLPolicy) expandSource(
|
||||
src string,
|
||||
machines types.Machines,
|
||||
) ([]string, error) {
|
||||
ipSet, err := pol.ExpandAlias(machines, src)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
prefixes := []string{}
|
||||
|
||||
for _, prefix := range ipSet.Prefixes() {
|
||||
prefixes = append(prefixes, prefix.String())
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
// expandalias has an input of either
|
||||
// - a user
|
||||
// - a group
|
||||
|
@ -544,16 +541,16 @@ func (pol *ACLPolicy) ExpandAlias(
|
|||
|
||||
// if alias is a group
|
||||
if isGroup(alias) {
|
||||
return pol.getIPsFromGroup(alias, machines)
|
||||
return pol.expandIPsFromGroup(alias, machines)
|
||||
}
|
||||
|
||||
// if alias is a tag
|
||||
if isTag(alias) {
|
||||
return pol.getIPsFromTag(alias, machines)
|
||||
return pol.expandIPsFromTag(alias, machines)
|
||||
}
|
||||
|
||||
// if alias is a user
|
||||
if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
|
||||
if ips, err := pol.expandIPsFromUser(alias, machines); ips != nil {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
|
@ -567,12 +564,12 @@ func (pol *ACLPolicy) ExpandAlias(
|
|||
|
||||
// if alias is an IP
|
||||
if ip, err := netip.ParseAddr(alias); err == nil {
|
||||
return pol.getIPsFromSingleIP(ip, machines)
|
||||
return pol.expandIPsFromSingleIP(ip, machines)
|
||||
}
|
||||
|
||||
// if alias is an IP Prefix (CIDR)
|
||||
if prefix, err := netip.ParsePrefix(alias); err == nil {
|
||||
return pol.getIPsFromIPPrefix(prefix, machines)
|
||||
return pol.expandIPsFromIPPrefix(prefix, machines)
|
||||
}
|
||||
|
||||
log.Warn().Msgf("No IPs found with the alias %v", alias)
|
||||
|
@ -591,7 +588,7 @@ func excludeCorrectlyTaggedNodes(
|
|||
out := types.Machines{}
|
||||
tags := []string{}
|
||||
for tag := range aclPolicy.TagOwners {
|
||||
owners, _ := getTagOwners(aclPolicy, user)
|
||||
owners, _ := expandOwnersFromTag(aclPolicy, user)
|
||||
ns := append(owners, user)
|
||||
if util.StringOrPrefixListContains(ns, user) {
|
||||
tags = append(tags, tag)
|
||||
|
@ -668,20 +665,9 @@ func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) {
|
|||
return &ports, nil
|
||||
}
|
||||
|
||||
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||
out := types.Machines{}
|
||||
for _, machine := range machines {
|
||||
if machine.User.Name == user {
|
||||
out = append(out, machine)
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// getTagOwners will return a list of user. An owner can be either a user or a group
|
||||
// expandOwnersFromTag will return a list of user. An owner can be either a user or a group
|
||||
// a group cannot be composed of groups.
|
||||
func getTagOwners(
|
||||
func expandOwnersFromTag(
|
||||
pol *ACLPolicy,
|
||||
tag string,
|
||||
) ([]string, error) {
|
||||
|
@ -696,7 +682,7 @@ func getTagOwners(
|
|||
}
|
||||
for _, owner := range ows {
|
||||
if isGroup(owner) {
|
||||
gs, err := pol.getUsersInGroup(owner)
|
||||
gs, err := pol.expandUsersFromGroup(owner)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
@ -709,9 +695,9 @@ func getTagOwners(
|
|||
return owners, nil
|
||||
}
|
||||
|
||||
// getUsersInGroup will return the list of user inside the group
|
||||
// expandUsersFromGroup will return the list of user inside the group
|
||||
// after some validation.
|
||||
func (pol *ACLPolicy) getUsersInGroup(
|
||||
func (pol *ACLPolicy) expandUsersFromGroup(
|
||||
group string,
|
||||
) ([]string, error) {
|
||||
users := []string{}
|
||||
|
@ -745,13 +731,13 @@ func (pol *ACLPolicy) getUsersInGroup(
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromGroup(
|
||||
func (pol *ACLPolicy) expandIPsFromGroup(
|
||||
group string,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
|
||||
users, err := pol.getUsersInGroup(group)
|
||||
users, err := pol.expandUsersFromGroup(group)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, err
|
||||
}
|
||||
|
@ -765,7 +751,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
|
|||
return build.IPSet()
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromTag(
|
||||
func (pol *ACLPolicy) expandIPsFromTag(
|
||||
alias string,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
|
@ -779,7 +765,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
|||
}
|
||||
|
||||
// find tag owners
|
||||
owners, err := getTagOwners(pol, alias)
|
||||
owners, err := expandOwnersFromTag(pol, alias)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
ipSet, _ := build.IPSet()
|
||||
|
@ -811,7 +797,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
|||
return build.IPSet()
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) getIPsForUser(
|
||||
func (pol *ACLPolicy) expandIPsFromUser(
|
||||
user string,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
|
@ -832,7 +818,7 @@ func (pol *ACLPolicy) getIPsForUser(
|
|||
return build.IPSet()
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromSingleIP(
|
||||
func (pol *ACLPolicy) expandIPsFromSingleIP(
|
||||
ip netip.Addr,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
|
@ -850,7 +836,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
|
|||
return build.IPSet()
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromIPPrefix(
|
||||
func (pol *ACLPolicy) expandIPsFromIPPrefix(
|
||||
prefix netip.Prefix,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
|
@ -885,10 +871,10 @@ func isTag(str string) bool {
|
|||
return strings.HasPrefix(str, "tag:")
|
||||
}
|
||||
|
||||
// getTags will return the tags of the current machine.
|
||||
// TagsOfMachine will return the tags of the current machine.
|
||||
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
|
||||
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
||||
func (pol *ACLPolicy) GetTagsOfMachine(
|
||||
func (pol *ACLPolicy) TagsOfMachine(
|
||||
machine types.Machine,
|
||||
) ([]string, []string) {
|
||||
validTags := make([]string, 0)
|
||||
|
@ -897,7 +883,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
|
|||
validTagMap := make(map[string]bool)
|
||||
invalidTagMap := make(map[string]bool)
|
||||
for _, tag := range machine.HostInfo.RequestTags {
|
||||
owners, err := getTagOwners(pol, tag)
|
||||
owners, err := expandOwnersFromTag(pol, tag)
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
invalidTagMap[tag] = true
|
||||
|
||||
|
@ -925,6 +911,17 @@ func (pol *ACLPolicy) GetTagsOfMachine(
|
|||
return validTags, invalidTags
|
||||
}
|
||||
|
||||
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||
out := types.Machines{}
|
||||
for _, machine := range machines {
|
||||
if machine.User.Name == user {
|
||||
out = append(out, machine)
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||
func FilterMachinesByACL(
|
||||
machine *types.Machine,
|
||||
|
|
|
@ -690,7 +690,7 @@ func Test_expandGroup(t *testing.T) {
|
|||
t.Run(test.name, func(t *testing.T) {
|
||||
viper.Set("oidc.strip_email_domain", test.args.stripEmail)
|
||||
|
||||
got, err := test.field.pol.getUsersInGroup(
|
||||
got, err := test.field.pol.expandUsersFromGroup(
|
||||
test.args.group,
|
||||
)
|
||||
|
||||
|
@ -779,7 +779,7 @@ func Test_expandTagOwners(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got, err := getTagOwners(
|
||||
got, err := expandOwnersFromTag(
|
||||
test.args.aclPolicy,
|
||||
test.args.tag,
|
||||
)
|
||||
|
@ -2022,7 +2022,7 @@ func Test_getTags(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
|
||||
gotValid, gotInvalid := test.args.aclPolicy.TagsOfMachine(
|
||||
test.args.machine,
|
||||
)
|
||||
for _, valid := range gotValid {
|
||||
|
|
Loading…
Reference in a new issue