feat(acls): rewrite functions to be testable

Rewrite some function to get rid of the dependency on Headscale object. This allows us
to write succinct test that are more easy to review and implement.

The improvements of the tests allowed to write the removal of the tagged hosts
from the namespace as specified here: https://tailscale.com/kb/1068/acl-tags/
This commit is contained in:
Adrien Raffin 2022-02-07 16:12:05 +01:00 committed by Adrien Raffin-Caboisse
parent 97eac3b938
commit de59946447
No known key found for this signature in database
GPG key ID: 7FB60532DEBEAD6A
3 changed files with 646 additions and 75 deletions

187
acls.go
View file

@ -2,7 +2,6 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
machines, err := h.ListAllMachines()
if err != nil {
return nil, err
}
for index, acl := range h.aclPolicy.ACLs {
if acl.Action != "accept" {
return nil, errInvalidAction
@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
srcIPs := []string{}
for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(user)
srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
destPorts := []tailcfg.NetPortRange{}
for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(ports)
dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
return rules, nil
}
func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u)
func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) {
return expandAlias(machines, aclPolicy, u)
}
func (h *Headscale) generateACLPolicyDestPorts(
machines []Machine,
aclPolicy ACLPolicy,
d string,
) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts(
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
expanded, err := h.expandAlias(alias)
expanded, err := expandAlias(machines, aclPolicy, alias)
if err != nil {
return nil, err
}
ports, err := h.expandPorts(tokens[len(tokens)-1])
ports, err := expandPorts(tokens[len(tokens)-1])
if err != nil {
return nil, err
}
@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts(
// - a group
// - a tag
// and transform these in IPAddresses
func (h *Headscale) expandAlias(alias string) ([]string, error) {
func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) {
ips := []string{}
if alias == "*" {
return []string{"*"}, nil
}
if strings.HasPrefix(alias, "group:") {
namespaces, err := h.expandGroup(alias)
namespaces, err := expandGroup(aclPolicy, alias)
if err != nil {
return nil, err
return ips, err
}
ips := []string{}
for _, n := range namespaces {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errInvalidNamespace
}
nodes := listMachinesInNamespace(machines, n)
for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...)
}
}
return ips, nil
}
if strings.HasPrefix(alias, "tag:") {
var ips []string
owners, err := h.expandTagOwners(alias)
owners, err := expandTagOwners(aclPolicy, alias)
if err != nil {
return nil, err
return ips, err
}
for _, namespace := range owners {
machines, err := h.ListMachinesInNamespace(namespace)
if err != nil {
if errors.Is(err, errNamespaceNotFound) {
continue
} else {
return nil, err
}
}
machines := listMachinesInNamespace(machines, namespace)
for _, machine := range machines {
if len(machine.HostInfo) == 0 {
continue
}
hi, err := machine.GetHostInfo()
if err != nil {
return nil, err
return ips, err
}
for _, t := range hi.RequestTags {
if alias == t {
@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
return ips, nil
}
n, err := h.GetNamespace(alias)
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
return nil, err
}
ips := []string{}
for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
}
// if alias is a namespace
nodes := listMachinesInNamespace(machines, alias)
nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
if err != nil {
return ips, err
}
for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
}
if len(ips) > 0 {
return ips, nil
}
if h, ok := h.aclPolicy.Hosts[alias]; ok {
// if alias is an host
if h, ok := aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil
}
// if alias is an IP
ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
// if alias is an CIDR
cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
return nil, errInvalidUserSection
return ips, errInvalidUserSection
}
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
// a group cannot be composed of groups
func (h *Headscale) expandTagOwners(owner string) ([]string, error) {
var owners []string
ows, ok := h.aclPolicy.TagOwners[owner]
if !ok {
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner)
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
// that are correctly tagged since they should not be listed as being in the namespace
// we assume in this function that we only have nodes from 1 namespace.
func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) {
out := []Machine{}
tags := []string{}
for tag, ns := range aclPolicy.TagOwners {
if containsString(ns, namespace) {
tags = append(tags, tag)
}
}
for _, ow := range ows {
if strings.HasPrefix(ow, "group:") {
gs, err := h.expandGroup(ow)
if err != nil {
return []string{}, err
// for each machine if tag is in tags list, don't append it.
for _, machine := range nodes {
if len(machine.HostInfo) == 0 {
out = append(out, machine)
continue
}
hi, err := machine.GetHostInfo()
if err != nil {
return out, err
}
found := false
for _, t := range hi.RequestTags {
if containsString(tags, t) {
found = true
break
}
owners = append(owners, gs...)
} else {
owners = append(owners, ow)
}
if !found {
out = append(out, machine)
}
}
return owners, nil
return out, nil
}
// expandGroup will return the list of namespace inside the group
// after some validation
func (h *Headscale) expandGroup(group string) ([]string, error) {
gs, ok := h.aclPolicy.Groups[group]
if !ok {
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
}
for _, g := range gs {
if strings.HasPrefix(g, "group:") {
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
}
}
return gs, nil
}
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
return &ports, nil
}
func listMachinesInNamespace(machines []Machine, namespace string) []Machine {
out := []Machine{}
for _, machine := range machines {
if machine.Namespace.Name == namespace {
out = append(out, machine)
}
}
return out
}
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
// a group cannot be composed of groups
func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) {
var owners []string
ows, ok := aclPolicy.TagOwners[tag]
if !ok {
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag)
}
for _, ow := range ows {
if strings.HasPrefix(ow, "group:") {
gs, err := expandGroup(aclPolicy, ow)
if err != nil {
return []string{}, err
}
owners = append(owners, gs...)
} else {
owners = append(owners, ow)
}
}
return owners, nil
}
// expandGroup will return the list of namespace inside the group
// after some validation
func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) {
gs, ok := aclPolicy.Groups[group]
if !ok {
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
}
for _, g := range gs {
if strings.HasPrefix(g, "group:") {
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
}
}
return gs, nil
}

View file

@ -2,10 +2,13 @@ package headscale
import (
"errors"
"reflect"
"testing"
"gopkg.in/check.v1"
"gorm.io/datatypes"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
func (s *Suite) TestWrongPath(c *check.C) {
@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
}
err = app.UpdateACLRules()
c.Assert(err, check.IsNil)
c.Logf("Rules: %v", app.aclRules)
c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2")
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2)
c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1")
c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1")
}
func (s *Suite) TestPortRange(c *check.C) {
@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) {
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
}
func Test_expandGroup(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
group string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "simple test",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
},
group: "group:test",
},
want: []string{"g1", "foo", "test"},
wantErr: false,
},
{
name: "InexistantGroup",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
},
group: "group:bar",
},
want: []string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandGroup(tt.args.aclPolicy, tt.args.group)
if (err != nil) != tt.wantErr {
t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandGroup() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandTagOwners(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
tag string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "simple tag",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"namespace1"}},
},
tag: "tag:test",
},
want: []string{"namespace1"},
wantErr: false,
},
{
name: "tag and group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"n1", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
},
tag: "tag:test",
},
want: []string{"n1", "bar"},
wantErr: false,
},
{
name: "namespace and group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"n1", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{"n1", "bar", "home"},
wantErr: false,
},
{
name: "invalid tag",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{},
wantErr: true,
},
{
name: "invalid group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:bar": []string{"n1", "foo"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag)
if (err != nil) != tt.wantErr {
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandTagOwners() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandPorts(t *testing.T) {
type args struct {
portsStr string
}
tests := []struct {
name string
args args
want *[]tailcfg.PortRange
wantErr bool
}{
{
name: "wildcard",
args: args{portsStr: "*"},
want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
},
wantErr: false,
},
{
name: "two ports",
args: args{portsStr: "80,443"},
want: &[]tailcfg.PortRange{
{First: 80, Last: 80},
{First: 443, Last: 443},
},
wantErr: false,
},
{
name: "a range and a port",
args: args{portsStr: "80-1024,443"},
want: &[]tailcfg.PortRange{
{First: 80, Last: 1024},
{First: 443, Last: 443},
},
wantErr: false,
},
{
name: "out of bounds",
args: args{portsStr: "854038"},
want: nil,
wantErr: true,
},
{
name: "wrong port",
args: args{portsStr: "85a38"},
want: nil,
wantErr: true,
},
{
name: "wrong port in first",
args: args{portsStr: "a-80"},
want: nil,
wantErr: true,
},
{
name: "wrong port in last",
args: args{portsStr: "80-85a38"},
want: nil,
wantErr: true,
},
{
name: "wrong port format",
args: args{portsStr: "80-85a38-3"},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandPorts(tt.args.portsStr)
if (err != nil) != tt.wantErr {
t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandPorts() = %v, want %v", got, tt.want)
}
})
}
}
func Test_listMachinesInNamespace(t *testing.T) {
type args struct {
machines []Machine
namespace string
}
tests := []struct {
name string
args args
want []Machine
}{
{
name: "1 machine in namespace",
args: args{
machines: []Machine{
{Namespace: Namespace{Name: "test"}},
},
namespace: "test",
},
want: []Machine{
{Namespace: Namespace{Name: "test"}},
},
},
{
name: "3 machines, 2 in namespace",
args: args{
machines: []Machine{
{ID: 1, Namespace: Namespace{Name: "test"}},
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
},
},
{
name: "5 machines, 0 in namespace",
args: args{
machines: []Machine{
{ID: 1, Namespace: Namespace{Name: "test"}},
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
{ID: 4, Namespace: Namespace{Name: "foo"}},
{ID: 5, Namespace: Namespace{Name: "foo"}},
},
namespace: "bar",
},
want: []Machine{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) {
t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandAlias(t *testing.T) {
type args struct {
machines []Machine
aclPolicy ACLPolicy
alias string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "wildcard",
args: args{
alias: "*",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}},
},
aclPolicy: ACLPolicy{},
},
want: []string{"*"},
wantErr: false,
},
{
name: "simple group",
args: args{
alias: "group:foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
},
},
want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"},
wantErr: false,
},
{
name: "wrong group",
args: args{
alias: "group:test",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
},
},
want: []string{},
wantErr: true,
},
{
name: "simple ipaddress",
args: args{
alias: "10.0.0.3",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.3"},
wantErr: false,
},
{
name: "private network",
args: args{
alias: "homeNetwork",
machines: []Machine{},
aclPolicy: ACLPolicy{
Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")},
},
},
want: []string{"192.168.1.0/24"},
wantErr: false,
},
{
name: "simple host",
args: args{
alias: "10.0.0.1",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.1"},
wantErr: false,
},
{
name: "simple CIDR",
args: args{
alias: "10.0.0.0/16",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.0/16"},
wantErr: false,
},
{
name: "simple tag",
args: args{
alias: "tag:test",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
},
want: []string{"100.64.0.1", "100.64.0.2"},
wantErr: false,
},
{
name: "No tag defined",
args: args{
alias: "tag:foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
},
},
want: []string{},
wantErr: true,
},
{
name: "list host in namespace without correctly tagged servers",
args: args{
alias: "foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
},
want: []string{"100.64.0.4"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias)
if (err != nil) != tt.wantErr {
t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
}
})
}
}
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
nodes []Machine
namespace string
}
tests := []struct {
name string
args args
want []Machine
wantErr bool
}{
{
name: "exclude nodes with valid tags",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
nodes: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
wantErr: false,
},
{
name: "all nodes have invalid tags, don't exclude them",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:foo": []string{"foo"}},
},
nodes: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace)
if (err != nil) != tt.wantErr {
t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*machine.Expiry)
}
func (h *Headscale) ListAllMachines() ([]Machine, error) {
machines := []Machine{}
if err := h.db.Preload("AuthKey").
Preload("AuthKey.Namespace").
Preload("Namespace").
Where("registered").
Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
func containsAddresses(inputs []string, addrs MachineAddresses) bool {
for _, addr := range addrs.ToStringSlice() {
if containsString(inputs, addr) {