diff --git a/acls.go b/acls.go index 281ea2c..6cae7e0 100644 --- a/acls.go +++ b/acls.go @@ -22,7 +22,8 @@ const errorInvalidTag = Error("invalid tag") const errorInvalidNamespace = Error("invalid namespace") const errorInvalidPortFormat = Error("invalid port format") -func (h *Headscale) LoadAclPolicy(path string) error { +// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules +func (h *Headscale) LoadACLPolicy(path string) error { policyFile, err := os.Open(path) if err != nil { return err @@ -35,6 +36,9 @@ func (h *Headscale) LoadAclPolicy(path string) error { return err } err = hujson.Unmarshal(b, &policy) + if err != nil { + return err + } if policy.IsZero() { return errorEmptyPolicy } @@ -61,7 +65,7 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { srcIPs := []string{} for j, u := range a.Users { fmt.Printf("acl %d, user %d: ", i, j) - srcs, err := h.generateAclPolicySrcIP(u) + srcs, err := h.generateACLPolicySrcIP(u) fmt.Printf(" -> %s\n", err) if err != nil { return nil, err @@ -73,7 +77,7 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { destPorts := []tailcfg.NetPortRange{} for j, d := range a.Ports { fmt.Printf("acl %d, port %d: ", i, j) - dests, err := h.generateAclPolicyDestPorts(d) + dests, err := h.generateACLPolicyDestPorts(d) fmt.Printf(" -> %s\n", err) if err != nil { return nil, err @@ -90,11 +94,11 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { return &rules, nil } -func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) { +func (h *Headscale) generateACLPolicySrcIP(u string) (*[]string, error) { return h.expandAlias(u) } -func (h *Headscale) generateAclPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) { +func (h *Headscale) generateACLPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) { tokens := strings.Split(d, ":") if len(tokens) < 2 || len(tokens) > 3 { return nil, errorInvalidPortFormat diff --git a/acls_test.go b/acls_test.go index 99c2909..f80349c 100644 --- a/acls_test.go +++ b/acls_test.go @@ -5,18 +5,18 @@ import ( ) func (s *Suite) TestWrongPath(c *check.C) { - err := h.LoadAclPolicy("asdfg") + err := h.LoadACLPolicy("asdfg") c.Assert(err, check.NotNil) } func (s *Suite) TestBrokenHuJson(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/broken.hujson") + err := h.LoadACLPolicy("./tests/acls/broken.hujson") c.Assert(err, check.NotNil) } func (s *Suite) TestInvalidPolicyHuson(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/invalid.hujson") + err := h.LoadACLPolicy("./tests/acls/invalid.hujson") c.Assert(err, check.NotNil) c.Assert(err, check.Equals, errorEmptyPolicy) } @@ -36,13 +36,13 @@ func (s *Suite) TestParseInvalidCIDR(c *check.C) { } func (s *Suite) TestCheckLoaded(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_1.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_1.hujson") c.Assert(err, check.IsNil) c.Assert(h.aclPolicy, check.NotNil) } func (s *Suite) TestValidCheckParsedHosts(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_1.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_1.hujson") c.Assert(err, check.IsNil) c.Assert(h.aclPolicy, check.NotNil) c.Assert(h.aclPolicy.IsZero(), check.Equals, false) @@ -50,7 +50,7 @@ func (s *Suite) TestValidCheckParsedHosts(c *check.C) { } func (s *Suite) TestRuleInvalidGeneration(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_invalid.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() @@ -59,7 +59,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { } func (s *Suite) TestBasicRule(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_basic_1.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() @@ -68,7 +68,7 @@ func (s *Suite) TestBasicRule(c *check.C) { } func (s *Suite) TestPortRange(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_basic_range.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() @@ -82,7 +82,7 @@ func (s *Suite) TestPortRange(c *check.C) { } func (s *Suite) TestPortWildcard(c *check.C) { - err := h.LoadAclPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") + err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() @@ -126,7 +126,7 @@ func (s *Suite) TestPortNamespace(c *check.C) { } db.Save(&m) - err = h.LoadAclPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson") + err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() @@ -171,7 +171,7 @@ func (s *Suite) TestPortGroup(c *check.C) { } db.Save(&m) - err = h.LoadAclPolicy("./tests/acls/acl_policy_basic_groups.hujson") + err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() diff --git a/acls_types.go b/acls_types.go index d385362..01e42d5 100644 --- a/acls_types.go +++ b/acls_types.go @@ -7,6 +7,7 @@ import ( "inet.af/netaddr" ) +// ACLPolicy represents a Tailscale ACL Policy type ACLPolicy struct { Groups Groups `json:"Groups"` Hosts Hosts `json:"Hosts"` @@ -15,24 +16,30 @@ type ACLPolicy struct { Tests []ACLTest `json:"Tests"` } +// ACL is a basic rule for the ACL Policy type ACL struct { Action string `json:"Action"` Users []string `json:"Users"` Ports []string `json:"Ports"` } +// Groups references a series of alias in the ACL rules type Groups map[string][]string +// Hosts are alias for IP addresses or subnets type Hosts map[string]netaddr.IPPrefix +// TagOwners specify what users (namespaces?) are allow to use certain tags type TagOwners map[string][]string +// ACLTest is not implemented, but should be use to check if a certain rule is allowed type ACLTest struct { User string `json:"User"` Allow []string `json:"Allow"` Deny []string `json:"Deny,omitempty"` } +// UnmarshalJSON allows to parse the Hosts directly into netaddr objects func (h *Hosts) UnmarshalJSON(data []byte) error { hosts := Hosts{} hs := make(map[string]string) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index c606b6d..a872ec5 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -121,7 +121,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { } // We are doing this here, as in the future could be cool to have it also hot-reload - err = h.LoadAclPolicy(absPath(viper.GetString("acl_policy_path"))) + err = h.LoadACLPolicy(absPath(viper.GetString("acl_policy_path"))) if err != nil { log.Printf("Could not load the ACL policy: %s", err) }