diff --git a/acls.go b/acls.go index c7376ee..799db95 100644 --- a/acls.go +++ b/acls.go @@ -23,6 +23,7 @@ const ( errInvalidGroup = Error("invalid group") errInvalidTag = Error("invalid tag") errInvalidPortFormat = Error("invalid port format") + errWildcardIsNeeded = Error("wildcard as port is required for the procotol") ) const ( @@ -134,9 +135,17 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { srcIPs = append(srcIPs, srcs...) } + protocols, needsWildcard, err := parseProtocol(acl.Protocol) + if err != nil { + log.Error(). + Msgf("Error parsing ACL %d. protocol unknown %s", index, acl.Protocol) + + return nil, err + } + destPorts := []tailcfg.NetPortRange{} for innerIndex, dest := range acl.Destinations { - dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest) + dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest, needsWildcard) if err != nil { log.Error(). Msgf("Error parsing ACL %d, Destination %d", index, innerIndex) @@ -149,6 +158,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { rules = append(rules, tailcfg.FilterRule{ SrcIPs: srcIPs, DstPorts: destPorts, + IPProto: protocols, }) } @@ -167,6 +177,7 @@ func (h *Headscale) generateACLPolicyDest( machines []Machine, aclPolicy ACLPolicy, dest string, + needsWildcard bool, ) ([]tailcfg.NetPortRange, error) { tokens := strings.Split(dest, ":") if len(tokens) < expectedTokenItems || len(tokens) > 3 { @@ -195,7 +206,7 @@ func (h *Headscale) generateACLPolicyDest( if err != nil { return nil, err } - ports, err := expandPorts(tokens[len(tokens)-1]) + ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard) if err != nil { return nil, err } @@ -214,6 +225,54 @@ func (h *Headscale) generateACLPolicyDest( return dests, nil } +// parseProtocol reads the proto field of the ACL and generates a list of +// protocols that will be allowed, following the IANA IP protocol number +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// +// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, +// as per Tailscale behaviour (see tailcfg.FilterRule). +// +// Also returns a boolean indicating if the protocol +// requires all the destinations to use wildcard as port number (only TCP, +// UDP and SCTP support specifying ports). +func parseProtocol(protocol string) ([]int, bool, error) { + switch protocol { + case "": + return []int{1, 58, 6, 17}, false, nil + case "igmp": + return []int{2}, true, nil + case "ipv4", "ip-in-ip": + return []int{4}, true, nil + case "tcp": + return []int{6}, false, nil + case "egp": + return []int{8}, true, nil + case "igp": + return []int{9}, true, nil + case "udp": + return []int{17}, false, nil + case "gre": + return []int{47}, true, nil + case "esp": + return []int{50}, true, nil + case "ah": + return []int{51}, true, nil + case "sctp": + return []int{132}, false, nil + case "icmp": + return []int{1, 58}, true, nil + + default: + protocolNumber, err := strconv.Atoi(protocol) + if err != nil { + return nil, false, err + } + needsWildcard := protocolNumber != 6 && protocolNumber != 17 && protocolNumber != 132 + + return []int{protocolNumber}, needsWildcard, nil + } +} + // expandalias has an input of either // - a namespace // - a group @@ -268,6 +327,7 @@ func expandAlias( alias, ) } + return ips, nil } else { return ips, err @@ -359,13 +419,17 @@ func excludeCorrectlyTaggedNodes( return out } -func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { +func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) { if portsStr == "*" { return &[]tailcfg.PortRange{ {First: portRangeBegin, Last: portRangeEnd}, }, nil } + if needsWildcard { + return nil, errWildcardIsNeeded + } + ports := []tailcfg.PortRange{} for _, portStr := range strings.Split(portsStr, ",") { rang := strings.Split(portStr, "-") diff --git a/acls_test.go b/acls_test.go index 9a7d8a6..eaf578b 100644 --- a/acls_test.go +++ b/acls_test.go @@ -628,7 +628,8 @@ func Test_expandTagOwners(t *testing.T) { func Test_expandPorts(t *testing.T) { type args struct { - portsStr string + portsStr string + needsWildcard bool } tests := []struct { name string @@ -638,15 +639,29 @@ func Test_expandPorts(t *testing.T) { }{ { name: "wildcard", - args: args{portsStr: "*"}, + args: args{portsStr: "*", needsWildcard: true}, want: &[]tailcfg.PortRange{ {First: portRangeBegin, Last: portRangeEnd}, }, wantErr: false, }, + { + name: "needs wildcard but does not require it", + args: args{portsStr: "*", needsWildcard: false}, + want: &[]tailcfg.PortRange{ + {First: portRangeBegin, Last: portRangeEnd}, + }, + wantErr: false, + }, + { + name: "needs wildcard but gets port", + args: args{portsStr: "80,443", needsWildcard: true}, + want: nil, + wantErr: true, + }, { name: "two Destinations", - args: args{portsStr: "80,443"}, + args: args{portsStr: "80,443", needsWildcard: false}, want: &[]tailcfg.PortRange{ {First: 80, Last: 80}, {First: 443, Last: 443}, @@ -655,7 +670,7 @@ func Test_expandPorts(t *testing.T) { }, { name: "a range and a port", - args: args{portsStr: "80-1024,443"}, + args: args{portsStr: "80-1024,443", needsWildcard: false}, want: &[]tailcfg.PortRange{ {First: 80, Last: 1024}, {First: 443, Last: 443}, @@ -664,38 +679,38 @@ func Test_expandPorts(t *testing.T) { }, { name: "out of bounds", - args: args{portsStr: "854038"}, + args: args{portsStr: "854038", needsWildcard: false}, want: nil, wantErr: true, }, { name: "wrong port", - args: args{portsStr: "85a38"}, + args: args{portsStr: "85a38", needsWildcard: false}, want: nil, wantErr: true, }, { name: "wrong port in first", - args: args{portsStr: "a-80"}, + args: args{portsStr: "a-80", needsWildcard: false}, want: nil, wantErr: true, }, { name: "wrong port in last", - args: args{portsStr: "80-85a38"}, + args: args{portsStr: "80-85a38", needsWildcard: false}, want: nil, wantErr: true, }, { name: "wrong port format", - args: args{portsStr: "80-85a38-3"}, + args: args{portsStr: "80-85a38-3", needsWildcard: false}, want: nil, wantErr: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := expandPorts(test.args.portsStr) + got, err := expandPorts(test.args.portsStr, test.args.needsWildcard) if (err != nil) != test.wantErr { t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr) diff --git a/acls_types.go b/acls_types.go index 6434509..1c952e2 100644 --- a/acls_types.go +++ b/acls_types.go @@ -21,7 +21,7 @@ type ACLPolicy struct { // ACL is a basic rule for the ACL Policy. type ACL struct { Action string `json:"action" yaml:"action"` - Protocol string `json:"protocol" yaml:"protocol"` + Protocol string `json:"proto" yaml:"proto"` Sources []string `json:"src" yaml:"src"` Destinations []string `json:"dst" yaml:"dst"` }