diff --git a/acls.go b/acls.go index 414be39..8ba8618 100644 --- a/acls.go +++ b/acls.go @@ -1,11 +1,15 @@ package headscale import ( + "encoding/json" "fmt" "io" + "log" "os" + "strconv" "strings" + "github.com/davecgh/go-spew/spew" "github.com/tailscale/hujson" "inet.af/netaddr" "tailscale.com/tailcfg" @@ -15,6 +19,9 @@ const errorEmptyPolicy = Error("empty policy") const errorInvalidAction = Error("invalid action") const errorInvalidUserSection = Error("invalid user section") const errorInvalidGroup = Error("invalid group") +const errorInvalidTag = Error("invalid tag") +const errorInvalidNamespace = Error("invalid namespace") +const errorInvalidPortFormat = Error("invalid port format") func (h *Headscale) LoadPolicy(path string) error { policyFile, err := os.Open(path) @@ -59,33 +66,143 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { } r.SrcIPs = srcIPs + destPorts := []tailcfg.NetPortRange{} + for j, d := range a.Ports { + fmt.Printf("acl %d, port %d: ", i, j) + dests, err := h.generateAclPolicyDestPorts(d) + fmt.Printf(" -> %s\n", err) + if err != nil { + return nil, err + } + destPorts = append(destPorts, *dests...) + } + + rules = append(rules, tailcfg.FilterRule{ + SrcIPs: srcIPs, + DstPorts: destPorts, + }) } + // fmt.Println(rules) + spew.Dump(rules) return &rules, nil } func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) { - if u == "*" { - fmt.Printf("%s -> wildcard", u) + return h.expandAlias(u) +} + +func (h *Headscale) generateAclPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) { + tokens := strings.Split(d, ":") + if len(tokens) < 2 || len(tokens) > 3 { + return nil, errorInvalidPortFormat + } + + var alias string + // We can have here stuff like: + // git-server:* + // 192.168.1.0/24:22 + // tag:montreal-webserver:80,443 + // tag:api-server:443 + // example-host-1:* + if len(tokens) == 2 { + alias = tokens[0] + } else { + alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) + } + + expanded, err := h.expandAlias(alias) + if err != nil { + return nil, err + } + ports, err := h.expandPorts(tokens[len(tokens)-1]) + if err != nil { + return nil, err + } + + dests := []tailcfg.NetPortRange{} + for _, d := range *expanded { + for _, p := range *ports { + pr := tailcfg.NetPortRange{ + IP: d, + Ports: p, + } + dests = append(dests, pr) + } + } + return &dests, nil +} + +func (h *Headscale) expandAlias(s string) (*[]string, error) { + if s == "*" { + fmt.Printf("%s -> wildcard", s) return &[]string{"*"}, nil } - if strings.HasPrefix(u, "group:") { - fmt.Printf("%s -> group", u) - if _, ok := h.aclPolicy.Groups[u]; !ok { + if strings.HasPrefix(s, "group:") { + fmt.Printf("%s -> group", s) + if _, ok := h.aclPolicy.Groups[s]; !ok { return nil, errorInvalidGroup } - return nil, nil + ips := []string{} + for _, n := range h.aclPolicy.Groups[s] { + nodes, err := h.ListMachinesInNamespace(n) + if err != nil { + return nil, errorInvalidNamespace + } + for _, node := range *nodes { + ips = append(ips, node.IPAddress) + } + } + return &ips, nil } - if strings.HasPrefix(u, "tag:") { - fmt.Printf("%s -> tag", u) - return nil, nil + if strings.HasPrefix(s, "tag:") { + fmt.Printf("%s -> tag", s) + if _, ok := h.aclPolicy.TagOwners[s]; !ok { + return nil, errorInvalidTag + } + + // This will have HORRIBLE performance. + // We need to change the data model to better store tags + db, err := h.db() + if err != nil { + log.Printf("Cannot open DB: %s", err) + return nil, err + } + machines := []Machine{} + if err = db.Where("registered").Find(&machines).Error; err != nil { + log.Printf("Error accessing db: %s", err) + return nil, err + } + ips := []string{} + for _, m := range machines { + hostinfo := tailcfg.Hostinfo{} + if len(m.HostInfo) != 0 { + hi, err := m.HostInfo.MarshalJSON() + if err != nil { + return nil, err + } + err = json.Unmarshal(hi, &hostinfo) + if err != nil { + return nil, err + } + + // FIXME: Check TagOwners allows this + for _, t := range hostinfo.RequestTags { + if s[4:] == t { + ips = append(ips, m.IPAddress) + break + } + } + } + } + return &ips, nil } - n, err := h.GetNamespace(u) + n, err := h.GetNamespace(s) if err == nil { - fmt.Printf("%s -> namespace %s", u, n.Name) + fmt.Printf("%s -> namespace %s", s, n.Name) nodes, err := h.ListMachinesInNamespace(n.Name) if err != nil { return nil, err @@ -97,23 +214,60 @@ func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) { return &ips, nil } - if h, ok := h.aclPolicy.Hosts[u]; ok { - fmt.Printf("%s -> host %s", u, h) + if h, ok := h.aclPolicy.Hosts[s]; ok { + fmt.Printf("%s -> host %s", s, h) return &[]string{h.String()}, nil } - ip, err := netaddr.ParseIP(u) + ip, err := netaddr.ParseIP(s) if err == nil { - fmt.Printf(" %s -> ip %s", u, ip) + fmt.Printf(" %s -> ip %s", s, ip) return &[]string{ip.String()}, nil } - cidr, err := netaddr.ParseIPPrefix(u) + cidr, err := netaddr.ParseIPPrefix(s) if err == nil { - fmt.Printf("%s -> cidr %s", u, cidr) + fmt.Printf("%s -> cidr %s", s, cidr) return &[]string{cidr.String()}, nil } - fmt.Printf("%s: cannot be mapped to anything\n", u) + fmt.Printf("%s: cannot be mapped to anything\n", s) return nil, errorInvalidUserSection } + +func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { + if s == "*" { + return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil + } + + ports := []tailcfg.PortRange{} + for _, p := range strings.Split(s, ",") { + rang := strings.Split(p, "-") + if len(rang) == 1 { + pi, err := strconv.ParseUint(rang[0], 10, 16) + if err != nil { + return nil, err + } + ports = append(ports, tailcfg.PortRange{ + First: uint16(pi), + Last: uint16(pi), + }) + } else if len(rang) == 2 { + start, err := strconv.ParseUint(rang[0], 10, 16) + if err != nil { + return nil, err + } + last, err := strconv.ParseUint(rang[1], 10, 16) + if err != nil { + return nil, err + } + ports = append(ports, tailcfg.PortRange{ + First: uint16(start), + Last: uint16(last), + }) + } else { + return nil, errorInvalidPortFormat + } + } + return &ports, nil +} diff --git a/acls_test.go b/acls_test.go index fe77932..97f0d33 100644 --- a/acls_test.go +++ b/acls_test.go @@ -58,12 +58,21 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(rules, check.IsNil) } -func (s *Suite) TestRuleGeneration(c *check.C) { - err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") +func (s *Suite) TestBasicRule(c *check.C) { + err := h.LoadPolicy("./tests/acls/acl_policy_basic_1.hujson") c.Assert(err, check.IsNil) rules, err := h.generateACLRules() c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - + c.Assert(rules, check.IsNil) } + +// func (s *Suite) TestRuleGeneration(c *check.C) { +// err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") +// c.Assert(err, check.IsNil) + +// rules, err := h.generateACLRules() +// c.Assert(err, check.IsNil) +// c.Assert(rules, check.NotNil) + +// }