clear up the acl function naming

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-04-26 11:19:47 +02:00 committed by Juan Font
parent 6de53e2f8d
commit 5bbbe437df
2 changed files with 34 additions and 25 deletions

45
acls.go
View file

@ -128,7 +128,7 @@ func (h *Headscale) UpdateACLRules() error {
return errEmptyPolicy return errEmptyPolicy
} }
rules, err := generateACLRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain) rules, err := generateFilterRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain)
if err != nil { if err != nil {
return err return err
} }
@ -224,24 +224,29 @@ func expandACLPeerAddr(srcIP string) []string {
return []string{srcIP} return []string{srcIP}
} }
func generateACLRules( // generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func generateFilterRules(
machines []Machine, machines []Machine,
aclPolicy ACLPolicy, pol ACLPolicy,
stripEmaildomain bool, stripEmaildomain bool,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
for index, acl := range aclPolicy.ACLs { for index, acl := range pol.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
return nil, errInvalidAction return nil, errInvalidAction
} }
srcIPs := []string{} srcIPs := []string{}
for innerIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := generateACLPolicySrc(machines, aclPolicy, src, stripEmaildomain) srcs, err := pol.getIPsFromSource(src, machines, stripEmaildomain)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Source %d", index, innerIndex) Interface("src", src).
Int("ACL index", index).
Int("Src index", srcIndex).
Msgf("Error parsing ACL")
return nil, err return nil, err
} }
@ -257,17 +262,19 @@ func generateACLRules(
} }
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations { for destIndex, dest := range acl.Destinations {
dests, err := generateACLPolicyDest( dests, err := pol.getNetPortRangeFromDestination(
machines,
aclPolicy,
dest, dest,
machines,
needsWildcard, needsWildcard,
stripEmaildomain, stripEmaildomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Destination %d", index, innerIndex) Interface("dest", dest).
Int("ACL index", index).
Int("dest index", destIndex).
Msgf("Error parsing ACL")
return nil, err return nil, err
} }
@ -388,19 +395,21 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
}, nil }, nil
} }
func generateACLPolicySrc( // getIPsFromSource returns a set of Source IPs that would be associated
machines []Machine, // with the given src alias.
pol ACLPolicy, func (pol *ACLPolicy) getIPsFromSource(
src string, src string,
machines []Machine,
stripEmaildomain bool, stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
return pol.expandAlias(machines, src, stripEmaildomain) return pol.expandAlias(machines, src, stripEmaildomain)
} }
func generateACLPolicyDest( // getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange
machines []Machine, // which are associated with the dest alias.
pol ACLPolicy, func (pol *ACLPolicy) getNetPortRangeFromDestination(
dest string, dest string,
machines []Machine,
needsWildcard bool, needsWildcard bool,
stripEmaildomain bool, stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {

View file

@ -54,7 +54,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
} }
@ -411,7 +411,7 @@ func (s *Suite) TestPortRange(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -425,7 +425,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -439,7 +439,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -455,7 +455,7 @@ func (s *Suite) TestPortWildcardYAML(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -498,7 +498,7 @@ func (s *Suite) TestPortUser(c *check.C) {
machines, err := app.ListMachines() machines, err := app.ListMachines()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules(machines, *app.aclPolicy, false) rules, err := generateFilterRules(machines, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -541,7 +541,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
machines, err := app.ListMachines() machines, err := app.ListMachines()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := generateACLRules(machines, *app.aclPolicy, false) rules, err := generateFilterRules(machines, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)