headscale/acls.go

269 lines
5.7 KiB
Go
Raw Normal View History

2021-07-03 03:55:32 -06:00
package headscale
import (
"encoding/json"
2021-07-03 09:31:32 -06:00
"fmt"
2021-07-03 03:55:32 -06:00
"io"
"log"
2021-07-03 03:55:32 -06:00
"os"
"strconv"
2021-07-03 09:31:32 -06:00
"strings"
2021-07-03 03:55:32 -06:00
"github.com/tailscale/hujson"
2021-07-03 09:31:32 -06:00
"inet.af/netaddr"
"tailscale.com/tailcfg"
2021-07-03 03:55:32 -06:00
)
2021-07-03 09:31:32 -06:00
// Sentinel errors returned while loading an ACL policy and expanding its
// entries into filter rules.
const (
	errorEmptyPolicy        = Error("empty policy")
	errorInvalidAction      = Error("invalid action")
	errorInvalidUserSection = Error("invalid user section")
	errorInvalidGroup       = Error("invalid group")
	errorInvalidTag         = Error("invalid tag")
	errorInvalidNamespace   = Error("invalid namespace")
	errorInvalidPortFormat  = Error("invalid port format")
)
2021-07-03 03:55:32 -06:00
2021-07-04 05:33:00 -06:00
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
func (h *Headscale) LoadACLPolicy(path string) error {
2021-07-03 03:55:32 -06:00
policyFile, err := os.Open(path)
if err != nil {
2021-07-03 09:31:32 -06:00
return err
2021-07-03 03:55:32 -06:00
}
defer policyFile.Close()
var policy ACLPolicy
b, err := io.ReadAll(policyFile)
if err != nil {
2021-07-03 09:31:32 -06:00
return err
2021-07-03 03:55:32 -06:00
}
err = hujson.Unmarshal(b, &policy)
2021-07-04 05:33:00 -06:00
if err != nil {
return err
}
2021-07-03 03:55:32 -06:00
if policy.IsZero() {
2021-07-03 09:31:32 -06:00
return errorEmptyPolicy
2021-07-03 03:55:32 -06:00
}
2021-07-03 09:31:32 -06:00
h.aclPolicy = &policy
2021-07-04 05:24:05 -06:00
rules, err := h.generateACLRules()
if err != nil {
return err
}
h.aclRules = rules
return nil
2021-07-03 09:31:32 -06:00
}
func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs {
if a.Action != "accept" {
return nil, errorInvalidAction
}
r := tailcfg.FilterRule{}
srcIPs := []string{}
for j, u := range a.Users {
2021-07-04 05:33:00 -06:00
srcs, err := h.generateACLPolicySrcIP(u)
2021-07-03 09:31:32 -06:00
if err != nil {
2021-07-04 05:47:59 -06:00
log.Printf("Error parsing ACL %d, User %d", i, j)
2021-07-03 09:31:32 -06:00
return nil, err
}
srcIPs = append(srcIPs, *srcs...)
}
r.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports {
2021-07-04 05:33:00 -06:00
dests, err := h.generateACLPolicyDestPorts(d)
if err != nil {
2021-07-04 05:47:59 -06:00
log.Printf("Error parsing ACL %d, Port %d", i, j)
return nil, err
}
destPorts = append(destPorts, *dests...)
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs,
DstPorts: destPorts,
})
2021-07-03 09:31:32 -06:00
}
return &rules, nil
}
2021-07-04 05:33:00 -06:00
// generateACLPolicySrcIP expands a single entry of an ACL's Users list into
// the source IP strings it stands for. It delegates entirely to expandAlias
// and passes that function's result and error through unchanged.
func (h *Headscale) generateACLPolicySrcIP(u string) (*[]string, error) {
	srcIPs, err := h.expandAlias(u)
	return srcIPs, err
}
2021-07-04 05:33:00 -06:00
// generateACLPolicyDestPorts expands a single entry of an ACL's Ports list
// into one tailcfg.NetPortRange per (address, port range) pair.
//
// The entry is split on ":"; the last piece is the port list and whatever
// precedes it is the alias. Exactly two or three pieces are accepted, any
// other count yields errorInvalidPortFormat. Errors from expandAlias and
// expandPorts are returned as they are.
func (h *Headscale) generateACLPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) {
	parts := strings.Split(d, ":")

	// Entries seen here look like:
	// git-server:*
	// 192.168.1.0/24:22
	// tag:montreal-webserver:80,443
	// tag:api-server:443
	// example-host-1:*
	var alias string
	switch len(parts) {
	case 2:
		alias = parts[0]
	case 3:
		// The alias itself contains a colon (e.g. "tag:api-server").
		alias = fmt.Sprintf("%s:%s", parts[0], parts[1])
	default:
		return nil, errorInvalidPortFormat
	}
	portSpec := parts[len(parts)-1]

	addrs, err := h.expandAlias(alias)
	if err != nil {
		return nil, err
	}

	portRanges, err := h.expandPorts(portSpec)
	if err != nil {
		return nil, err
	}

	dests := []tailcfg.NetPortRange{}
	for _, addr := range *addrs {
		for _, portRange := range *portRanges {
			dests = append(dests, tailcfg.NetPortRange{
				IP:    addr,
				Ports: portRange,
			})
		}
	}

	return &dests, nil
}
func (h *Headscale) expandAlias(s string) (*[]string, error) {
if s == "*" {
2021-07-03 09:31:32 -06:00
return &[]string{"*"}, nil
}
if strings.HasPrefix(s, "group:") {
if _, ok := h.aclPolicy.Groups[s]; !ok {
2021-07-03 09:31:32 -06:00
return nil, errorInvalidGroup
}
ips := []string{}
for _, n := range h.aclPolicy.Groups[s] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
}
for _, node := range *nodes {
ips = append(ips, node.IPAddress)
}
}
return &ips, nil
2021-07-03 09:31:32 -06:00
}
if strings.HasPrefix(s, "tag:") {
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
return nil, errorInvalidTag
}
// This will have HORRIBLE performance.
// We need to change the data model to better store tags
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
machines := []Machine{}
if err = db.Where("registered").Find(&machines).Error; err != nil {
return nil, err
}
ips := []string{}
for _, m := range machines {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(hi, &hostinfo)
if err != nil {
return nil, err
}
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
if s[4:] == t {
ips = append(ips, m.IPAddress)
break
}
}
}
}
return &ips, nil
2021-07-03 09:31:32 -06:00
}
n, err := h.GetNamespace(s)
2021-07-03 09:31:32 -06:00
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
return nil, err
}
ips := []string{}
for _, n := range *nodes {
ips = append(ips, n.IPAddress)
}
return &ips, nil
}
if h, ok := h.aclPolicy.Hosts[s]; ok {
2021-07-03 09:31:32 -06:00
return &[]string{h.String()}, nil
}
ip, err := netaddr.ParseIP(s)
2021-07-03 09:31:32 -06:00
if err == nil {
return &[]string{ip.String()}, nil
}
cidr, err := netaddr.ParseIPPrefix(s)
2021-07-03 09:31:32 -06:00
if err == nil {
return &[]string{cidr.String()}, nil
}
return nil, errorInvalidUserSection
2021-07-03 03:55:32 -06:00
}
// expandPorts parses the port part of an ACL destination into port ranges.
//
// "*" yields the single range 0-65535. Otherwise s is a comma-separated list
// whose items are either one port ("443") or a "first-last" range
// ("8000-8100"); a single port becomes a range with equal First and Last.
// Numbers are parsed as base-10 unsigned 16-bit values and a parse failure
// is returned as the strconv error. An item containing more than one "-"
// yields errorInvalidPortFormat.
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
	if s == "*" {
		return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
	}

	ports := []tailcfg.PortRange{}
	for _, item := range strings.Split(s, ",") {
		bounds := strings.Split(item, "-")

		switch len(bounds) {
		case 1:
			port, err := strconv.ParseUint(bounds[0], 10, 16)
			if err != nil {
				return nil, err
			}
			ports = append(ports, tailcfg.PortRange{
				First: uint16(port),
				Last:  uint16(port),
			})

		case 2:
			first, err := strconv.ParseUint(bounds[0], 10, 16)
			if err != nil {
				return nil, err
			}
			last, err := strconv.ParseUint(bounds[1], 10, 16)
			if err != nil {
				return nil, err
			}
			ports = append(ports, tailcfg.PortRange{
				First: uint16(first),
				Last:  uint16(last),
			})

		default:
			return nil, errorInvalidPortFormat
		}
	}

	return &ports, nil
}