diff --git a/acls.go b/acls.go index c86e315..9dd1260 100644 --- a/acls.go +++ b/acls.go @@ -2,7 +2,6 @@ package headscale import ( "encoding/json" - "errors" "fmt" "io" "os" @@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} + machines, err := h.ListAllMachines() + if err != nil { + return nil, err + } + for index, acl := range h.aclPolicy.ACLs { if acl.Action != "accept" { return nil, errInvalidAction @@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { srcIPs := []string{} for innerIndex, user := range acl.Users { - srcs, err := h.generateACLPolicySrcIP(user) + srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user) if err != nil { log.Error(). Msgf("Error parsing ACL %d, User %d", index, innerIndex) @@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { destPorts := []tailcfg.NetPortRange{} for innerIndex, ports := range acl.Ports { - dests, err := h.generateACLPolicyDestPorts(ports) + dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports) if err != nil { log.Error(). Msgf("Error parsing ACL %d, Port %d", index, innerIndex) @@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { return rules, nil } -func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) { - return h.expandAlias(u) +func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) { + return expandAlias(machines, aclPolicy, u) } func (h *Headscale) generateACLPolicyDestPorts( + machines []Machine, + aclPolicy ACLPolicy, d string, ) ([]tailcfg.NetPortRange, error) { tokens := strings.Split(d, ":") @@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts( alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) } - expanded, err := h.expandAlias(alias) + expanded, err := expandAlias(machines, aclPolicy, alias) if err != nil { return nil, err } - ports, err := h.expandPorts(tokens[len(tokens)-1]) + ports, err := expandPorts(tokens[len(tokens)-1]) if err != nil { return nil, err } @@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts( // - a group // - a tag // and transform these in IPAddresses -func (h *Headscale) expandAlias(alias string) ([]string, error) { +func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) { + ips := []string{} if alias == "*" { return []string{"*"}, nil } if strings.HasPrefix(alias, "group:") { - namespaces, err := h.expandGroup(alias) + namespaces, err := expandGroup(aclPolicy, alias) if err != nil { - return nil, err + return ips, err } - ips := []string{} for _, n := range namespaces { - nodes, err := h.ListMachinesInNamespace(n) - if err != nil { - return nil, errInvalidNamespace - } + nodes := listMachinesInNamespace(machines, n) for _, node := range nodes { ips = append(ips, node.IPAddresses.ToStringSlice()...) } } - return ips, nil } if strings.HasPrefix(alias, "tag:") { - var ips []string - owners, err := h.expandTagOwners(alias) + owners, err := expandTagOwners(aclPolicy, alias) if err != nil { - return nil, err + return ips, err } for _, namespace := range owners { - machines, err := h.ListMachinesInNamespace(namespace) - if err != nil { - if errors.Is(err, errNamespaceNotFound) { - continue - } else { - return nil, err - } - } + machines := listMachinesInNamespace(machines, namespace) for _, machine := range machines { if len(machine.HostInfo) == 0 { continue } hi, err := machine.GetHostInfo() if err != nil { - return nil, err + return ips, err } for _, t := range hi.RequestTags { if alias == t { @@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) { return ips, nil } - n, err := h.GetNamespace(alias) - if err == nil { - nodes, err := h.ListMachinesInNamespace(n.Name) - if err != nil { - return nil, err - } - ips := []string{} - for _, n := range nodes { - ips = append(ips, n.IPAddresses.ToStringSlice()...) - } - + // if alias is a namespace + nodes := listMachinesInNamespace(machines, alias) + nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias) + if err != nil { + return ips, err + } + for _, n := range nodes { + ips = append(ips, n.IPAddresses.ToStringSlice()...) + } + if len(ips) > 0 { return ips, nil } - if h, ok := h.aclPolicy.Hosts[alias]; ok { + // if alias is an host + if h, ok := aclPolicy.Hosts[alias]; ok { return []string{h.String()}, nil } + // if alias is an IP ip, err := netaddr.ParseIP(alias) if err == nil { return []string{ip.String()}, nil } + // if alias is an CIDR cidr, err := netaddr.ParseIPPrefix(alias) if err == nil { return []string{cidr.String()}, nil } - return nil, errInvalidUserSection + return ips, errInvalidUserSection } -// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group -// a group cannot be composed of groups -func (h *Headscale) expandTagOwners(owner string) ([]string, error) { - var owners []string - ows, ok := h.aclPolicy.TagOwners[owner] - if !ok { - return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner) +// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones +// that are correctly tagged since they should not be listed as being in the namespace +// we assume in this function that we only have nodes from 1 namespace. +func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) { + out := []Machine{} + tags := []string{} + for tag, ns := range aclPolicy.TagOwners { + if containsString(ns, namespace) { + tags = append(tags, tag) + } } - for _, ow := range ows { - if strings.HasPrefix(ow, "group:") { - gs, err := h.expandGroup(ow) - if err != nil { - return []string{}, err + // for each machine if tag is in tags list, don't append it. + for _, machine := range nodes { + if len(machine.HostInfo) == 0 { + out = append(out, machine) + continue + } + hi, err := machine.GetHostInfo() + if err != nil { + return out, err + } + found := false + for _, t := range hi.RequestTags { + if containsString(tags, t) { + found = true + break } - owners = append(owners, gs...) - } else { - owners = append(owners, ow) + } + if !found { + out = append(out, machine) } } - return owners, nil + return out, nil } -// expandGroup will return the list of namespace inside the group -// after some validation -func (h *Headscale) expandGroup(group string) ([]string, error) { - gs, ok := h.aclPolicy.Groups[group] - if !ok { - return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup) - } - for _, g := range gs { - if strings.HasPrefix(g, "group:") { - return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup) - } - } - return gs, nil -} - -func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { +func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { if portsStr == "*" { return &[]tailcfg.PortRange{ {First: portRangeBegin, Last: portRangeEnd}, @@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { return &ports, nil } + +func listMachinesInNamespace(machines []Machine, namespace string) []Machine { + out := []Machine{} + for _, machine := range machines { + if machine.Namespace.Name == namespace { + out = append(out, machine) + } + } + return out +} + +// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group +// a group cannot be composed of groups +func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) { + var owners []string + ows, ok := aclPolicy.TagOwners[tag] + if !ok { + return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag) + } + for _, ow := range ows { + if strings.HasPrefix(ow, "group:") { + gs, err := expandGroup(aclPolicy, ow) + if err != nil { + return []string{}, err + } + owners = append(owners, gs...) + } else { + owners = append(owners, ow) + } + } + return owners, nil +} + +// expandGroup will return the list of namespace inside the group +// after some validation +func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) { + gs, ok := aclPolicy.Groups[group] + if !ok { + return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup) + } + for _, g := range gs { + if strings.HasPrefix(g, "group:") { + return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup) + } + } + return gs, nil +} diff --git a/acls_test.go b/acls_test.go index f2fb6a0..8bedb47 100644 --- a/acls_test.go +++ b/acls_test.go @@ -2,10 +2,13 @@ package headscale import ( "errors" + "reflect" + "testing" "gopkg.in/check.v1" "gorm.io/datatypes" "inet.af/netaddr" + "tailscale.com/tailcfg" ) func (s *Suite) TestWrongPath(c *check.C) { @@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { } err = app.UpdateACLRules() c.Assert(err, check.IsNil) - c.Logf("Rules: %v", app.aclRules) c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0) + c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) + c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2") + c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2) + c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) + c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) + c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1") + c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) + c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) + c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1") } func (s *Suite) TestPortRange(c *check.C) { @@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) } + +func Test_expandGroup(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + group string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "simple test", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}}, + }, + group: "group:test", + }, + want: []string{"g1", "foo", "test"}, + wantErr: false, + }, + { + name: "InexistantGroup", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}}, + }, + group: "group:bar", + }, + want: []string{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandGroup(tt.args.aclPolicy, tt.args.group) + if (err != nil) != tt.wantErr { + t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandGroup() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandTagOwners(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + tag string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "simple tag", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"namespace1"}}, + }, + tag: "tag:test", + }, + want: []string{"namespace1"}, + wantErr: false, + }, + { + name: "tag and group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"n1", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, + }, + tag: "tag:test", + }, + want: []string{"n1", "bar"}, + wantErr: false, + }, + { + name: "namespace and group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"n1", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{"n1", "bar", "home"}, + wantErr: false, + }, + { + name: "invalid tag", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{}, + wantErr: true, + }, + { + name: "invalid group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:bar": []string{"n1", "foo"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag) + if (err != nil) != tt.wantErr { + t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandTagOwners() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandPorts(t *testing.T) { + type args struct { + portsStr string + } + tests := []struct { + name string + args args + want *[]tailcfg.PortRange + wantErr bool + }{ + { + name: "wildcard", + args: args{portsStr: "*"}, + want: &[]tailcfg.PortRange{ + {First: portRangeBegin, Last: portRangeEnd}, + }, + wantErr: false, + }, + { + name: "two ports", + args: args{portsStr: "80,443"}, + want: &[]tailcfg.PortRange{ + {First: 80, Last: 80}, + {First: 443, Last: 443}, + }, + wantErr: false, + }, + { + name: "a range and a port", + args: args{portsStr: "80-1024,443"}, + want: &[]tailcfg.PortRange{ + {First: 80, Last: 1024}, + {First: 443, Last: 443}, + }, + wantErr: false, + }, + { + name: "out of bounds", + args: args{portsStr: "854038"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port", + args: args{portsStr: "85a38"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port in first", + args: args{portsStr: "a-80"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port in last", + args: args{portsStr: "80-85a38"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port format", + args: args{portsStr: "80-85a38-3"}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandPorts(tt.args.portsStr) + if (err != nil) != tt.wantErr { + t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandPorts() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_listMachinesInNamespace(t *testing.T) { + type args struct { + machines []Machine + namespace string + } + tests := []struct { + name string + args args + want []Machine + }{ + { + name: "1 machine in namespace", + args: args{ + machines: []Machine{ + {Namespace: Namespace{Name: "test"}}, + }, + namespace: "test", + }, + want: []Machine{ + {Namespace: Namespace{Name: "test"}}, + }, + }, + { + name: "3 machines, 2 in namespace", + args: args{ + machines: []Machine{ + {ID: 1, Namespace: Namespace{Name: "test"}}, + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + }, + }, + { + name: "5 machines, 0 in namespace", + args: args{ + machines: []Machine{ + {ID: 1, Namespace: Namespace{Name: "test"}}, + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + {ID: 4, Namespace: Namespace{Name: "foo"}}, + {ID: 5, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "bar", + }, + want: []Machine{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) { + t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandAlias(t *testing.T) { + type args struct { + machines []Machine + aclPolicy ACLPolicy + alias string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "wildcard", + args: args{ + alias: "*", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}}, + }, + aclPolicy: ACLPolicy{}, + }, + want: []string{"*"}, + wantErr: false, + }, + { + name: "simple group", + args: args{ + alias: "group:foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + }, + }, + want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + wantErr: false, + }, + { + name: "wrong group", + args: args{ + alias: "group:test", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + }, + }, + want: []string{}, + wantErr: true, + }, + { + name: "simple ipaddress", + args: args{ + alias: "10.0.0.3", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.3"}, + wantErr: false, + }, + { + name: "private network", + args: args{ + alias: "homeNetwork", + machines: []Machine{}, + aclPolicy: ACLPolicy{ + Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")}, + }, + }, + want: []string{"192.168.1.0/24"}, + wantErr: false, + }, + { + name: "simple host", + args: args{ + alias: "10.0.0.1", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.1"}, + wantErr: false, + }, + { + name: "simple CIDR", + args: args{ + alias: "10.0.0.0/16", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.0/16"}, + wantErr: false, + }, + { + name: "simple tag", + args: args{ + alias: "tag:test", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + }, + want: []string{"100.64.0.1", "100.64.0.2"}, + wantErr: false, + }, + { + name: "No tag defined", + args: args{ + alias: "tag:foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, + }, + }, + want: []string{}, + wantErr: true, + }, + { + name: "list host in namespace without correctly tagged servers", + args: args{ + alias: "foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + }, + want: []string{"100.64.0.4"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias) + if (err != nil) != tt.wantErr { + t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandAlias() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_excludeCorrectlyTaggedNodes(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + nodes []Machine + namespace string + } + tests := []struct { + name string + args args + want []Machine + wantErr bool + }{ + { + name: "exclude nodes with valid tags", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + nodes: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + wantErr: false, + }, + { + name: "all nodes have invalid tags, don't exclude them", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:foo": []string{"foo"}}, + }, + nodes: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace) + if (err != nil) != tt.wantErr { + t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/machine.go b/machine.go index 231c179..01fc927 100644 --- a/machine.go +++ b/machine.go @@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool { return time.Now().UTC().After(*machine.Expiry) } +func (h *Headscale) ListAllMachines() ([]Machine, error) { + machines := []Machine{} + if err := h.db.Preload("AuthKey"). + Preload("AuthKey.Namespace"). + Preload("Namespace"). + Where("registered"). + Find(&machines).Error; err != nil { + return nil, err + } + + return machines, nil +} + func containsAddresses(inputs []string, addrs MachineAddresses) bool { for _, addr := range addrs.ToStringSlice() { if containsString(inputs, addr) {