feat(acls): rewrite functions to be testable
Rewrite some function to get rid of the dependency on Headscale object. This allows us to write succinct test that are more easy to review and implement. The improvements of the tests allowed to write the removal of the tagged hosts from the namespace as specified here: https://tailscale.com/kb/1068/acl-tags/
This commit is contained in:
parent
97eac3b938
commit
de59946447
3 changed files with 646 additions and 75 deletions
175
acls.go
175
acls.go
|
@ -2,7 +2,6 @@ package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error {
|
||||||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
|
|
||||||
|
machines, err := h.ListAllMachines()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for index, acl := range h.aclPolicy.ACLs {
|
for index, acl := range h.aclPolicy.ACLs {
|
||||||
if acl.Action != "accept" {
|
if acl.Action != "accept" {
|
||||||
return nil, errInvalidAction
|
return nil, errInvalidAction
|
||||||
|
@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
|
|
||||||
srcIPs := []string{}
|
srcIPs := []string{}
|
||||||
for innerIndex, user := range acl.Users {
|
for innerIndex, user := range acl.Users {
|
||||||
srcs, err := h.generateACLPolicySrcIP(user)
|
srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||||
|
@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
|
|
||||||
destPorts := []tailcfg.NetPortRange{}
|
destPorts := []tailcfg.NetPortRange{}
|
||||||
for innerIndex, ports := range acl.Ports {
|
for innerIndex, ports := range acl.Ports {
|
||||||
dests, err := h.generateACLPolicyDestPorts(ports)
|
dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||||
|
@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
|
func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) {
|
||||||
return h.expandAlias(u)
|
return expandAlias(machines, aclPolicy, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateACLPolicyDestPorts(
|
func (h *Headscale) generateACLPolicyDestPorts(
|
||||||
|
machines []Machine,
|
||||||
|
aclPolicy ACLPolicy,
|
||||||
d string,
|
d string,
|
||||||
) ([]tailcfg.NetPortRange, error) {
|
) ([]tailcfg.NetPortRange, error) {
|
||||||
tokens := strings.Split(d, ":")
|
tokens := strings.Split(d, ":")
|
||||||
|
@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
||||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
expanded, err := h.expandAlias(alias)
|
expanded, err := expandAlias(machines, aclPolicy, alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ports, err := h.expandPorts(tokens[len(tokens)-1])
|
ports, err := expandPorts(tokens[len(tokens)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
||||||
// - a group
|
// - a group
|
||||||
// - a tag
|
// - a tag
|
||||||
// and transform these in IPAddresses
|
// and transform these in IPAddresses
|
||||||
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) {
|
||||||
|
ips := []string{}
|
||||||
if alias == "*" {
|
if alias == "*" {
|
||||||
return []string{"*"}, nil
|
return []string{"*"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(alias, "group:") {
|
if strings.HasPrefix(alias, "group:") {
|
||||||
namespaces, err := h.expandGroup(alias)
|
namespaces, err := expandGroup(aclPolicy, alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return ips, err
|
||||||
}
|
}
|
||||||
ips := []string{}
|
|
||||||
for _, n := range namespaces {
|
for _, n := range namespaces {
|
||||||
nodes, err := h.ListMachinesInNamespace(n)
|
nodes := listMachinesInNamespace(machines, n)
|
||||||
if err != nil {
|
|
||||||
return nil, errInvalidNamespace
|
|
||||||
}
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
ips = append(ips, node.IPAddresses.ToStringSlice()...)
|
ips = append(ips, node.IPAddresses.ToStringSlice()...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(alias, "tag:") {
|
if strings.HasPrefix(alias, "tag:") {
|
||||||
var ips []string
|
owners, err := expandTagOwners(aclPolicy, alias)
|
||||||
owners, err := h.expandTagOwners(alias)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return ips, err
|
||||||
}
|
}
|
||||||
for _, namespace := range owners {
|
for _, namespace := range owners {
|
||||||
machines, err := h.ListMachinesInNamespace(namespace)
|
machines := listMachinesInNamespace(machines, namespace)
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errNamespaceNotFound) {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
if len(machine.HostInfo) == 0 {
|
if len(machine.HostInfo) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hi, err := machine.GetHostInfo()
|
hi, err := machine.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return ips, err
|
||||||
}
|
}
|
||||||
for _, t := range hi.RequestTags {
|
for _, t := range hi.RequestTags {
|
||||||
if alias == t {
|
if alias == t {
|
||||||
|
@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := h.GetNamespace(alias)
|
// if alias is a namespace
|
||||||
if err == nil {
|
nodes := listMachinesInNamespace(machines, alias)
|
||||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return ips, err
|
||||||
}
|
}
|
||||||
ips := []string{}
|
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
ips = append(ips, n.IPAddresses.ToStringSlice()...)
|
ips = append(ips, n.IPAddresses.ToStringSlice()...)
|
||||||
}
|
}
|
||||||
|
if len(ips) > 0 {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
// if alias is an host
|
||||||
|
if h, ok := aclPolicy.Hosts[alias]; ok {
|
||||||
return []string{h.String()}, nil
|
return []string{h.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if alias is an IP
|
||||||
ip, err := netaddr.ParseIP(alias)
|
ip, err := netaddr.ParseIP(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{ip.String()}, nil
|
return []string{ip.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if alias is an CIDR
|
||||||
cidr, err := netaddr.ParseIPPrefix(alias)
|
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{cidr.String()}, nil
|
return []string{cidr.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errInvalidUserSection
|
return ips, errInvalidUserSection
|
||||||
}
|
}
|
||||||
|
|
||||||
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
|
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
|
||||||
// a group cannot be composed of groups
|
// that are correctly tagged since they should not be listed as being in the namespace
|
||||||
func (h *Headscale) expandTagOwners(owner string) ([]string, error) {
|
// we assume in this function that we only have nodes from 1 namespace.
|
||||||
var owners []string
|
func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) {
|
||||||
ows, ok := h.aclPolicy.TagOwners[owner]
|
out := []Machine{}
|
||||||
if !ok {
|
tags := []string{}
|
||||||
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner)
|
for tag, ns := range aclPolicy.TagOwners {
|
||||||
|
if containsString(ns, namespace) {
|
||||||
|
tags = append(tags, tag)
|
||||||
}
|
}
|
||||||
for _, ow := range ows {
|
}
|
||||||
if strings.HasPrefix(ow, "group:") {
|
// for each machine if tag is in tags list, don't append it.
|
||||||
gs, err := h.expandGroup(ow)
|
for _, machine := range nodes {
|
||||||
|
if len(machine.HostInfo) == 0 {
|
||||||
|
out = append(out, machine)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hi, err := machine.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return out, err
|
||||||
}
|
}
|
||||||
owners = append(owners, gs...)
|
found := false
|
||||||
} else {
|
for _, t := range hi.RequestTags {
|
||||||
owners = append(owners, ow)
|
if containsString(tags, t) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return owners, nil
|
if !found {
|
||||||
|
out = append(out, machine)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// expandGroup will return the list of namespace inside the group
|
func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||||
// after some validation
|
|
||||||
func (h *Headscale) expandGroup(group string) ([]string, error) {
|
|
||||||
gs, ok := h.aclPolicy.Groups[group]
|
|
||||||
if !ok {
|
|
||||||
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
|
|
||||||
}
|
|
||||||
for _, g := range gs {
|
|
||||||
if strings.HasPrefix(g, "group:") {
|
|
||||||
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return gs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
|
||||||
if portsStr == "*" {
|
if portsStr == "*" {
|
||||||
return &[]tailcfg.PortRange{
|
return &[]tailcfg.PortRange{
|
||||||
{First: portRangeBegin, Last: portRangeEnd},
|
{First: portRangeBegin, Last: portRangeEnd},
|
||||||
|
@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||||
|
|
||||||
return &ports, nil
|
return &ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func listMachinesInNamespace(machines []Machine, namespace string) []Machine {
|
||||||
|
out := []Machine{}
|
||||||
|
for _, machine := range machines {
|
||||||
|
if machine.Namespace.Name == namespace {
|
||||||
|
out = append(out, machine)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
|
||||||
|
// a group cannot be composed of groups
|
||||||
|
func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) {
|
||||||
|
var owners []string
|
||||||
|
ows, ok := aclPolicy.TagOwners[tag]
|
||||||
|
if !ok {
|
||||||
|
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag)
|
||||||
|
}
|
||||||
|
for _, ow := range ows {
|
||||||
|
if strings.HasPrefix(ow, "group:") {
|
||||||
|
gs, err := expandGroup(aclPolicy, ow)
|
||||||
|
if err != nil {
|
||||||
|
return []string{}, err
|
||||||
|
}
|
||||||
|
owners = append(owners, gs...)
|
||||||
|
} else {
|
||||||
|
owners = append(owners, ow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return owners, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandGroup will return the list of namespace inside the group
|
||||||
|
// after some validation
|
||||||
|
func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) {
|
||||||
|
gs, ok := aclPolicy.Groups[group]
|
||||||
|
if !ok {
|
||||||
|
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
|
||||||
|
}
|
||||||
|
for _, g := range gs {
|
||||||
|
if strings.HasPrefix(g, "group:") {
|
||||||
|
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gs, nil
|
||||||
|
}
|
||||||
|
|
521
acls_test.go
521
acls_test.go
|
@ -2,10 +2,13 @@ package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestWrongPath(c *check.C) {
|
func (s *Suite) TestWrongPath(c *check.C) {
|
||||||
|
@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
|
||||||
}
|
}
|
||||||
err = app.UpdateACLRules()
|
err = app.UpdateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Logf("Rules: %v", app.aclRules)
|
|
||||||
c.Assert(app.aclRules, check.HasLen, 1)
|
c.Assert(app.aclRules, check.HasLen, 1)
|
||||||
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0)
|
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
|
||||||
|
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2")
|
||||||
|
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2)
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1")
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||||
|
c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortRange(c *check.C) {
|
func (s *Suite) TestPortRange(c *check.C) {
|
||||||
|
@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
c.Assert(len(ips), check.Equals, 1)
|
c.Assert(len(ips), check.Equals, 1)
|
||||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
|
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_expandGroup(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
aclPolicy ACLPolicy
|
||||||
|
group string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple test",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
|
||||||
|
},
|
||||||
|
group: "group:test",
|
||||||
|
},
|
||||||
|
want: []string{"g1", "foo", "test"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "InexistantGroup",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
|
||||||
|
},
|
||||||
|
group: "group:bar",
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := expandGroup(tt.args.aclPolicy, tt.args.group)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("expandGroup() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_expandTagOwners(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
aclPolicy ACLPolicy
|
||||||
|
tag string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple tag",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"namespace1"}},
|
||||||
|
},
|
||||||
|
tag: "tag:test",
|
||||||
|
},
|
||||||
|
want: []string{"namespace1"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tag and group",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:foo": []string{"n1", "bar"}},
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
||||||
|
},
|
||||||
|
tag: "tag:test",
|
||||||
|
},
|
||||||
|
want: []string{"n1", "bar"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "namespace and group",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:foo": []string{"n1", "bar"}},
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
|
||||||
|
},
|
||||||
|
tag: "tag:test",
|
||||||
|
},
|
||||||
|
want: []string{"n1", "bar", "home"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid tag",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}},
|
||||||
|
},
|
||||||
|
tag: "tag:test",
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid group",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:bar": []string{"n1", "foo"}},
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
|
||||||
|
},
|
||||||
|
tag: "tag:test",
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("expandTagOwners() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_expandPorts(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
portsStr string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *[]tailcfg.PortRange
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wildcard",
|
||||||
|
args: args{portsStr: "*"},
|
||||||
|
want: &[]tailcfg.PortRange{
|
||||||
|
{First: portRangeBegin, Last: portRangeEnd},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two ports",
|
||||||
|
args: args{portsStr: "80,443"},
|
||||||
|
want: &[]tailcfg.PortRange{
|
||||||
|
{First: 80, Last: 80},
|
||||||
|
{First: 443, Last: 443},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "a range and a port",
|
||||||
|
args: args{portsStr: "80-1024,443"},
|
||||||
|
want: &[]tailcfg.PortRange{
|
||||||
|
{First: 80, Last: 1024},
|
||||||
|
{First: 443, Last: 443},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "out of bounds",
|
||||||
|
args: args{portsStr: "854038"},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong port",
|
||||||
|
args: args{portsStr: "85a38"},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong port in first",
|
||||||
|
args: args{portsStr: "a-80"},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong port in last",
|
||||||
|
args: args{portsStr: "80-85a38"},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong port format",
|
||||||
|
args: args{portsStr: "80-85a38-3"},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := expandPorts(tt.args.portsStr)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("expandPorts() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_listMachinesInNamespace(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
machines []Machine
|
||||||
|
namespace string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []Machine
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1 machine in namespace",
|
||||||
|
args: args{
|
||||||
|
machines: []Machine{
|
||||||
|
{Namespace: Namespace{Name: "test"}},
|
||||||
|
},
|
||||||
|
namespace: "test",
|
||||||
|
},
|
||||||
|
want: []Machine{
|
||||||
|
{Namespace: Namespace{Name: "test"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3 machines, 2 in namespace",
|
||||||
|
args: args{
|
||||||
|
machines: []Machine{
|
||||||
|
{ID: 1, Namespace: Namespace{Name: "test"}},
|
||||||
|
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
namespace: "foo",
|
||||||
|
},
|
||||||
|
want: []Machine{
|
||||||
|
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "5 machines, 0 in namespace",
|
||||||
|
args: args{
|
||||||
|
machines: []Machine{
|
||||||
|
{ID: 1, Namespace: Namespace{Name: "test"}},
|
||||||
|
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{ID: 4, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{ID: 5, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
namespace: "bar",
|
||||||
|
},
|
||||||
|
want: []Machine{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_expandAlias(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
machines []Machine
|
||||||
|
aclPolicy ACLPolicy
|
||||||
|
alias string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wildcard",
|
||||||
|
args: args{
|
||||||
|
alias: "*",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{},
|
||||||
|
},
|
||||||
|
want: []string{"*"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple group",
|
||||||
|
args: args{
|
||||||
|
alias: "group:foo",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong group",
|
||||||
|
args: args{
|
||||||
|
alias: "group:test",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple ipaddress",
|
||||||
|
args: args{
|
||||||
|
alias: "10.0.0.3",
|
||||||
|
machines: []Machine{},
|
||||||
|
aclPolicy: ACLPolicy{},
|
||||||
|
},
|
||||||
|
want: []string{"10.0.0.3"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private network",
|
||||||
|
args: args{
|
||||||
|
alias: "homeNetwork",
|
||||||
|
machines: []Machine{},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{"192.168.1.0/24"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple host",
|
||||||
|
args: args{
|
||||||
|
alias: "10.0.0.1",
|
||||||
|
machines: []Machine{},
|
||||||
|
aclPolicy: ACLPolicy{},
|
||||||
|
},
|
||||||
|
want: []string{"10.0.0.1"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple CIDR",
|
||||||
|
args: args{
|
||||||
|
alias: "10.0.0.0/16",
|
||||||
|
machines: []Machine{},
|
||||||
|
aclPolicy: ACLPolicy{},
|
||||||
|
},
|
||||||
|
want: []string{"10.0.0.0/16"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple tag",
|
||||||
|
args: args{
|
||||||
|
alias: "tag:test",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{"100.64.0.1", "100.64.0.2"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No tag defined",
|
||||||
|
args: args{
|
||||||
|
alias: "tag:foo",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list host in namespace without correctly tagged servers",
|
||||||
|
args: args{
|
||||||
|
alias: "foo",
|
||||||
|
machines: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []string{"100.64.0.4"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
aclPolicy ACLPolicy
|
||||||
|
nodes []Machine
|
||||||
|
namespace string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []Machine
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exclude nodes with valid tags",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||||
|
},
|
||||||
|
nodes: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
namespace: "foo",
|
||||||
|
},
|
||||||
|
want: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all nodes have invalid tags, don't exclude them",
|
||||||
|
args: args{
|
||||||
|
aclPolicy: ACLPolicy{
|
||||||
|
TagOwners: TagOwners{"tag:foo": []string{"foo"}},
|
||||||
|
},
|
||||||
|
nodes: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
namespace: "foo",
|
||||||
|
},
|
||||||
|
want: []Machine{
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||||
|
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
13
machine.go
13
machine.go
|
@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool {
|
||||||
return time.Now().UTC().After(*machine.Expiry)
|
return time.Now().UTC().After(*machine.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) ListAllMachines() ([]Machine, error) {
|
||||||
|
machines := []Machine{}
|
||||||
|
if err := h.db.Preload("AuthKey").
|
||||||
|
Preload("AuthKey.Namespace").
|
||||||
|
Preload("Namespace").
|
||||||
|
Where("registered").
|
||||||
|
Find(&machines).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return machines, nil
|
||||||
|
}
|
||||||
|
|
||||||
func containsAddresses(inputs []string, addrs MachineAddresses) bool {
|
func containsAddresses(inputs []string, addrs MachineAddresses) bool {
|
||||||
for _, addr := range addrs.ToStringSlice() {
|
for _, addr := range addrs.ToStringSlice() {
|
||||||
if containsString(inputs, addr) {
|
if containsString(inputs, addr) {
|
||||||
|
|
Loading…
Reference in a new issue