rename acl "get" funcs to "expand" for consistency

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-19 09:17:50 +02:00 committed by Kristoffer Dalby
parent 155cc072f7
commit 19dc0ac702
4 changed files with 59 additions and 62 deletions

View file

@ -338,7 +338,7 @@ func (api headscaleV1APIServer) ListMachines(
response := make([]*v1.Machine, len(machines)) response := make([]*v1.Machine, len(machines))
for index, machine := range machines { for index, machine := range machines {
m := machine.Proto() m := machine.Proto()
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( validTags, invalidTags := api.h.ACLPolicy.TagsOfMachine(
machine, machine,
) )
m.InvalidTags = invalidTags m.InvalidTags = invalidTags

View file

@ -104,7 +104,7 @@ func tailNode(
online := machine.IsOnline() online := machine.IsOnline()
tags, _ := pol.GetTagsOfMachine(machine) tags, _ := pol.TagsOfMachine(machine)
tags = lo.Uniq(append(tags, machine.ForcedTags...)) tags = lo.Uniq(append(tags, machine.ForcedTags...))
node := tailcfg.Node{ node := tailcfg.Node{

View file

@ -114,9 +114,6 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
return &policy, nil return &policy, nil
} }
// TODO(kradalby): This needs to be replace with something that generates
// the rules as needed and not stores it on the global object, rules are
// per node and that should be taken into account.
func GenerateFilterAndSSHRules( func GenerateFilterAndSSHRules(
policy *ACLPolicy, policy *ACLPolicy,
machine *types.Machine, machine *types.Machine,
@ -169,7 +166,7 @@ func (pol *ACLPolicy) generateFilterRules(
srcIPs := []string{} srcIPs := []string{}
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.getIPsFromSource(src, machines) srcs, err := pol.expandSource(src, machines)
if err != nil { if err != nil {
log.Error(). log.Error().
Interface("src", src). Interface("src", src).
@ -338,7 +335,7 @@ func (pol *ACLPolicy) generateSSHRules(
Any: true, Any: true,
}) })
} else if isGroup(rawSrc) { } else if isGroup(rawSrc) {
users, err := pol.getUsersInGroup(rawSrc) users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex) Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
@ -401,26 +398,6 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
}, nil }, nil
} }
// getIPsFromSource returns a set of Source IPs that would be associated
// with the given src alias.
func (pol *ACLPolicy) getIPsFromSource(
src string,
machines types.Machines,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
func parseDestination(dest string) (string, string, error) { func parseDestination(dest string) (string, string, error) {
var tokens []string var tokens []string
@ -520,6 +497,26 @@ func parseProtocol(protocol string) ([]int, bool, error) {
} }
} }
// expandSource returns a set of Source IPs that would be associated
// with the given src alias.
func (pol *ACLPolicy) expandSource(
src string,
machines types.Machines,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
// expandalias has an input of either // expandalias has an input of either
// - a user // - a user
// - a group // - a group
@ -544,16 +541,16 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines) return pol.expandIPsFromGroup(alias, machines)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.getIPsFromTag(alias, machines) return pol.expandIPsFromTag(alias, machines)
} }
// if alias is a user // if alias is a user
if ips, err := pol.getIPsForUser(alias, machines); ips != nil { if ips, err := pol.expandIPsFromUser(alias, machines); ips != nil {
return ips, err return ips, err
} }
@ -567,12 +564,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is an IP // if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil { if ip, err := netip.ParseAddr(alias); err == nil {
return pol.getIPsFromSingleIP(ip, machines) return pol.expandIPsFromSingleIP(ip, machines)
} }
// if alias is an IP Prefix (CIDR) // if alias is an IP Prefix (CIDR)
if prefix, err := netip.ParsePrefix(alias); err == nil { if prefix, err := netip.ParsePrefix(alias); err == nil {
return pol.getIPsFromIPPrefix(prefix, machines) return pol.expandIPsFromIPPrefix(prefix, machines)
} }
log.Warn().Msgf("No IPs found with the alias %v", alias) log.Warn().Msgf("No IPs found with the alias %v", alias)
@ -591,7 +588,7 @@ func excludeCorrectlyTaggedNodes(
out := types.Machines{} out := types.Machines{}
tags := []string{} tags := []string{}
for tag := range aclPolicy.TagOwners { for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user) owners, _ := expandOwnersFromTag(aclPolicy, user)
ns := append(owners, user) ns := append(owners, user)
if util.StringOrPrefixListContains(ns, user) { if util.StringOrPrefixListContains(ns, user) {
tags = append(tags, tag) tags = append(tags, tag)
@ -668,20 +665,9 @@ func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) {
return &ports, nil return &ports, nil
} }
func filterMachinesByUser(machines types.Machines, user string) types.Machines { // expandOwnersFromTag will return a list of user. An owner can be either a user or a group
out := types.Machines{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
}
}
return out
}
// getTagOwners will return a list of user. An owner can be either a user or a group
// a group cannot be composed of groups. // a group cannot be composed of groups.
func getTagOwners( func expandOwnersFromTag(
pol *ACLPolicy, pol *ACLPolicy,
tag string, tag string,
) ([]string, error) { ) ([]string, error) {
@ -696,7 +682,7 @@ func getTagOwners(
} }
for _, owner := range ows { for _, owner := range ows {
if isGroup(owner) { if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner) gs, err := pol.expandUsersFromGroup(owner)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -709,9 +695,9 @@ func getTagOwners(
return owners, nil return owners, nil
} }
// getUsersInGroup will return the list of user inside the group // expandUsersFromGroup will return the list of user inside the group
// after some validation. // after some validation.
func (pol *ACLPolicy) getUsersInGroup( func (pol *ACLPolicy) expandUsersFromGroup(
group string, group string,
) ([]string, error) { ) ([]string, error) {
users := []string{} users := []string{}
@ -745,13 +731,13 @@ func (pol *ACLPolicy) getUsersInGroup(
return users, nil return users, nil
} }
func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) expandIPsFromGroup(
group string, group string,
machines types.Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group) users, err := pol.expandUsersFromGroup(group)
if err != nil { if err != nil {
return &netipx.IPSet{}, err return &netipx.IPSet{}, err
} }
@ -765,7 +751,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
return build.IPSet() return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) expandIPsFromTag(
alias string, alias string,
machines types.Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
@ -779,7 +765,7 @@ func (pol *ACLPolicy) getIPsFromTag(
} }
// find tag owners // find tag owners
owners, err := getTagOwners(pol, alias) owners, err := expandOwnersFromTag(pol, alias)
if err != nil { if err != nil {
if errors.Is(err, ErrInvalidTag) { if errors.Is(err, ErrInvalidTag) {
ipSet, _ := build.IPSet() ipSet, _ := build.IPSet()
@ -811,7 +797,7 @@ func (pol *ACLPolicy) getIPsFromTag(
return build.IPSet() return build.IPSet()
} }
func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) expandIPsFromUser(
user string, user string,
machines types.Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
@ -832,7 +818,7 @@ func (pol *ACLPolicy) getIPsForUser(
return build.IPSet() return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromSingleIP( func (pol *ACLPolicy) expandIPsFromSingleIP(
ip netip.Addr, ip netip.Addr,
machines types.Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
@ -850,7 +836,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
return build.IPSet() return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromIPPrefix( func (pol *ACLPolicy) expandIPsFromIPPrefix(
prefix netip.Prefix, prefix netip.Prefix,
machines types.Machines, machines types.Machines,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
@ -885,10 +871,10 @@ func isTag(str string) bool {
return strings.HasPrefix(str, "tag:") return strings.HasPrefix(str, "tag:")
} }
// getTags will return the tags of the current machine. // TagsOfMachine will return the tags of the current machine.
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. // Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
func (pol *ACLPolicy) GetTagsOfMachine( func (pol *ACLPolicy) TagsOfMachine(
machine types.Machine, machine types.Machine,
) ([]string, []string) { ) ([]string, []string) {
validTags := make([]string, 0) validTags := make([]string, 0)
@ -897,7 +883,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
validTagMap := make(map[string]bool) validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool)
for _, tag := range machine.HostInfo.RequestTags { for _, tag := range machine.HostInfo.RequestTags {
owners, err := getTagOwners(pol, tag) owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) { if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true invalidTagMap[tag] = true
@ -925,6 +911,17 @@ func (pol *ACLPolicy) GetTagsOfMachine(
return validTags, invalidTags return validTags, invalidTags
} }
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
out := types.Machines{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
}
}
return out
}
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. // FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
func FilterMachinesByACL( func FilterMachinesByACL(
machine *types.Machine, machine *types.Machine,

View file

@ -690,7 +690,7 @@ func Test_expandGroup(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
viper.Set("oidc.strip_email_domain", test.args.stripEmail) viper.Set("oidc.strip_email_domain", test.args.stripEmail)
got, err := test.field.pol.getUsersInGroup( got, err := test.field.pol.expandUsersFromGroup(
test.args.group, test.args.group,
) )
@ -779,7 +779,7 @@ func Test_expandTagOwners(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := getTagOwners( got, err := expandOwnersFromTag(
test.args.aclPolicy, test.args.aclPolicy,
test.args.tag, test.args.tag,
) )
@ -2022,7 +2022,7 @@ func Test_getTags(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( gotValid, gotInvalid := test.args.aclPolicy.TagsOfMachine(
test.args.machine, test.args.machine,
) )
for _, valid := range gotValid { for _, valid := range gotValid {