diff --git a/acls.go b/acls.go index d698e26..414be39 100644 --- a/acls.go +++ b/acls.go @@ -1,30 +1,119 @@ package headscale import ( + "fmt" "io" "os" + "strings" "github.com/tailscale/hujson" + "inet.af/netaddr" + "tailscale.com/tailcfg" ) -const errorInvalidPolicy = Error("invalid policy") +const errorEmptyPolicy = Error("empty policy") +const errorInvalidAction = Error("invalid action") +const errorInvalidUserSection = Error("invalid user section") +const errorInvalidGroup = Error("invalid group") -func (h *Headscale) ParsePolicy(path string) (*ACLPolicy, error) { +func (h *Headscale) LoadPolicy(path string) error { policyFile, err := os.Open(path) if err != nil { - return nil, err + return err } defer policyFile.Close() var policy ACLPolicy b, err := io.ReadAll(policyFile) if err != nil { - return nil, err + return err } err = hujson.Unmarshal(b, &policy) if policy.IsZero() { - return nil, errorInvalidPolicy + return errorEmptyPolicy } - return &policy, err + h.aclPolicy = &policy + return err +} + +func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { + rules := []tailcfg.FilterRule{} + + for i, a := range h.aclPolicy.ACLs { + if a.Action != "accept" { + return nil, errorInvalidAction + } + + r := tailcfg.FilterRule{} + + srcIPs := []string{} + for j, u := range a.Users { + fmt.Printf("acl %d, user %d: ", i, j) + srcs, err := h.generateAclPolicySrcIP(u) + fmt.Printf(" -> %s\n", err) + if err != nil { + return nil, err + } + srcIPs = append(srcIPs, *srcs...) + } + r.SrcIPs = srcIPs + + } + + return &rules, nil +} + +func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) { + if u == "*" { + fmt.Printf("%s -> wildcard", u) + return &[]string{"*"}, nil + } + + if strings.HasPrefix(u, "group:") { + fmt.Printf("%s -> group", u) + if _, ok := h.aclPolicy.Groups[u]; !ok { + return nil, errorInvalidGroup + } + return nil, nil + } + + if strings.HasPrefix(u, "tag:") { + fmt.Printf("%s -> tag", u) + return nil, nil + } + + n, err := h.GetNamespace(u) + if err == nil { + fmt.Printf("%s -> namespace %s", u, n.Name) + nodes, err := h.ListMachinesInNamespace(n.Name) + if err != nil { + return nil, err + } + ips := []string{} + for _, n := range *nodes { + ips = append(ips, n.IPAddress) + } + return &ips, nil + } + + if h, ok := h.aclPolicy.Hosts[u]; ok { + fmt.Printf("%s -> host %s", u, h) + return &[]string{h.String()}, nil + } + + ip, err := netaddr.ParseIP(u) + if err == nil { + fmt.Printf(" %s -> ip %s", u, ip) + return &[]string{ip.String()}, nil + } + + cidr, err := netaddr.ParseIPPrefix(u) + if err == nil { + fmt.Printf("%s -> cidr %s", u, cidr) + return &[]string{cidr.String()}, nil + } + + fmt.Printf("%s: cannot be mapped to anything\n", u) + return nil, errorInvalidUserSection } diff --git a/acls_test.go b/acls_test.go index 6b01242..fe77932 100644 --- a/acls_test.go +++ b/acls_test.go @@ -5,29 +5,65 @@ import ( ) func (s *Suite) TestWrongPath(c *check.C) { - _, err := h.ParsePolicy("asdfg") + err := h.LoadPolicy("asdfg") c.Assert(err, check.NotNil) } func (s *Suite) TestBrokenHuJson(c *check.C) { - _, err := h.ParsePolicy("./tests/acls/broken.hujson") + err := h.LoadPolicy("./tests/acls/broken.hujson") c.Assert(err, check.NotNil) } func (s *Suite) TestInvalidPolicyHuson(c *check.C) { - _, err := h.ParsePolicy("./tests/acls/invalid.hujson") + err := h.LoadPolicy("./tests/acls/invalid.hujson") c.Assert(err, check.NotNil) - c.Assert(err, check.Equals, errorInvalidPolicy) + c.Assert(err, check.Equals, errorEmptyPolicy) } -func (s *Suite) TestValidCheckHosts(c *check.C) { - p, err := h.ParsePolicy("./tests/acls/acl_policy_1.hujson") +func (s *Suite) TestParseHosts(c *check.C) { + var hs Hosts + err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`)) + c.Assert(hs, check.NotNil) c.Assert(err, check.IsNil) - c.Assert(p, check.NotNil) - c.Assert(p.IsZero(), check.Equals, false) - - hosts, err := p.GetHosts() - c.Assert(err, check.IsNil) - c.Assert(*hosts, check.HasLen, 2) +} + +func (s *Suite) TestParseInvalidCIDR(c *check.C) { + var hs Hosts + err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`)) + c.Assert(hs, check.IsNil) + c.Assert(err, check.NotNil) +} + +func (s *Suite) TestCheckLoaded(c *check.C) { + err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") + c.Assert(err, check.IsNil) + c.Assert(h.aclPolicy, check.NotNil) +} + +func (s *Suite) TestValidCheckParsedHosts(c *check.C) { + err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") + c.Assert(err, check.IsNil) + c.Assert(h.aclPolicy, check.NotNil) + c.Assert(h.aclPolicy.IsZero(), check.Equals, false) + c.Assert(h.aclPolicy.Hosts, check.HasLen, 2) +} + +func (s *Suite) TestRuleInvalidGeneration(c *check.C) { + err := h.LoadPolicy("./tests/acls/acl_policy_invalid.hujson") + c.Assert(err, check.IsNil) + + rules, err := h.generateACLRules() + c.Assert(err, check.NotNil) + c.Assert(rules, check.IsNil) +} + +func (s *Suite) TestRuleGeneration(c *check.C) { + err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") + c.Assert(err, check.IsNil) + + rules, err := h.generateACLRules() + c.Assert(err, check.IsNil) + c.Assert(rules, check.NotNil) + } diff --git a/acls_types.go b/acls_types.go index 08b383c..d385362 100644 --- a/acls_types.go +++ b/acls_types.go @@ -3,6 +3,7 @@ package headscale import ( "strings" + "github.com/tailscale/hujson" "inet.af/netaddr" ) @@ -22,12 +23,9 @@ type ACL struct { type Groups map[string][]string -type Hosts map[string]string +type Hosts map[string]netaddr.IPPrefix -type TagOwners struct { - TagMontrealWebserver []string `json:"tag:montreal-webserver"` - TagAPIServer []string `json:"tag:api-server"` -} +type TagOwners map[string][]string type ACLTest struct { User string `json:"User"` @@ -35,6 +33,27 @@ type ACLTest struct { Deny []string `json:"Deny,omitempty"` } +func (h *Hosts) UnmarshalJSON(data []byte) error { + hosts := Hosts{} + hs := make(map[string]string) + err := hujson.Unmarshal(data, &hs) + if err != nil { + return err + } + for k, v := range hs { + if !strings.Contains(v, "/") { + v = v + "/32" + } + prefix, err := netaddr.ParseIPPrefix(v) + if err != nil { + return err + } + hosts[k] = prefix + } + *h = hosts + return nil +} + // IsZero is perhaps a bit naive here func (p ACLPolicy) IsZero() bool { if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { @@ -42,18 +61,3 @@ func (p ACLPolicy) IsZero() bool { } return false } - -func (p ACLPolicy) GetHosts() (*map[string]netaddr.IPPrefix, error) { - hosts := make(map[string]netaddr.IPPrefix) - for k, v := range p.Hosts { - if !strings.Contains(v, "/") { - v = v + "/32" - } - prefix, err := netaddr.ParseIPPrefix(v) - if err != nil { - return nil, err - } - hosts[k] = prefix - } - return &hosts, nil -} diff --git a/app.go b/app.go index 0cdc310..4775c6e 100644 --- a/app.go +++ b/app.go @@ -49,6 +49,8 @@ type Headscale struct { publicKey *wgkey.Key privateKey *wgkey.Private + aclPolicy *ACLPolicy + pollMu sync.Mutex clientsPolling map[uint64]chan []byte // this is by all means a hackity hack }