From 155cc072f752f87b4d3e460b73fc91015b16482d Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby <kristoffer@tailscale.com>
Date: Mon, 19 Jun 2023 08:48:49 +0200
Subject: [PATCH] migrate last acl tests away from database

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
---
 hscontrol/db/acls_test.go     | 264 ----------------------------------
 hscontrol/db/machine_test.go  |   4 +-
 hscontrol/mapper/mapper.go    |   2 +-
 hscontrol/policy/acls.go      |   2 +-
 hscontrol/policy/acls_test.go | 253 +++++++++++++++++++++++++++++++-
 5 files changed, 254 insertions(+), 271 deletions(-)
 delete mode 100644 hscontrol/db/acls_test.go

diff --git a/hscontrol/db/acls_test.go b/hscontrol/db/acls_test.go
deleted file mode 100644
index 5c109c7..0000000
--- a/hscontrol/db/acls_test.go
+++ /dev/null
@@ -1,264 +0,0 @@
-package db
-
-import (
-	"net/netip"
-	"testing"
-
-	"github.com/google/go-cmp/cmp"
-	"github.com/juanfont/headscale/hscontrol/policy"
-	"github.com/juanfont/headscale/hscontrol/types"
-	"github.com/juanfont/headscale/hscontrol/util"
-	"github.com/stretchr/testify/assert"
-	"tailscale.com/tailcfg"
-)
-
-// TODO(kradalby):
-// Convert these tests to being non-database dependent and table driven. They are
-// very verbose, and dont really need the database.
-
-// this test should validate that we can expand a group in a TagOWner section and
-// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
-// the tag is matched in the Sources section.
-func TestValidExpandTagOwnersInSources(t *testing.T) {
-	hostInfo := tailcfg.Hostinfo{
-		OS:          "centos",
-		Hostname:    "testmachine",
-		RequestTags: []string{"tag:test"},
-	}
-
-	machine := types.Machine{
-		ID:          0,
-		MachineKey:  "foo",
-		NodeKey:     "bar",
-		DiscoKey:    "faa",
-		Hostname:    "testmachine",
-		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
-		UserID:      0,
-		User: types.User{
-			Name: "user1",
-		},
-		RegisterMethod: util.RegisterMethodAuthKey,
-		HostInfo:       types.HostInfo(hostInfo),
-	}
-
-	pol := &policy.ACLPolicy{
-		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}},
-		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
-		ACLs: []policy.ACL{
-			{
-				Action:       "accept",
-				Sources:      []string{"tag:test"},
-				Destinations: []string{"*:*"},
-			},
-		},
-	}
-
-	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
-	assert.NoError(t, err)
-
-	want := []tailcfg.FilterRule{
-		{
-			SrcIPs: []string{"100.64.0.1/32"},
-			DstPorts: []tailcfg.NetPortRange{
-				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
-				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
-			},
-		},
-	}
-
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
-	}
-}
-
-// need a test with:
-// tag on a host that isn't owned by a tag owners. So the user
-// of the host should be valid.
-func TestInvalidTagValidUser(t *testing.T) {
-	hostInfo := tailcfg.Hostinfo{
-		OS:          "centos",
-		Hostname:    "testmachine",
-		RequestTags: []string{"tag:foo"},
-	}
-
-	machine := types.Machine{
-		ID:          1,
-		MachineKey:  "12345",
-		NodeKey:     "bar",
-		DiscoKey:    "faa",
-		Hostname:    "testmachine",
-		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
-		UserID:      1,
-		User: types.User{
-			Name: "user1",
-		},
-		RegisterMethod: util.RegisterMethodAuthKey,
-		HostInfo:       types.HostInfo(hostInfo),
-	}
-
-	pol := &policy.ACLPolicy{
-		TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
-		ACLs: []policy.ACL{
-			{
-				Action:       "accept",
-				Sources:      []string{"user1"},
-				Destinations: []string{"*:*"},
-			},
-		},
-	}
-
-	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
-	assert.NoError(t, err)
-
-	want := []tailcfg.FilterRule{
-		{
-			SrcIPs: []string{"100.64.0.1/32"},
-			DstPorts: []tailcfg.NetPortRange{
-				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
-				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
-			},
-		},
-	}
-
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
-	}
-}
-
-// this test should validate that we can expand a group in a TagOWner section and
-// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
-// the tag is matched in the Destinations section.
-func TestValidExpandTagOwnersInDestinations(t *testing.T) {
-	hostInfo := tailcfg.Hostinfo{
-		OS:          "centos",
-		Hostname:    "testmachine",
-		RequestTags: []string{"tag:test"},
-	}
-
-	machine := types.Machine{
-		ID:          1,
-		MachineKey:  "12345",
-		NodeKey:     "bar",
-		DiscoKey:    "faa",
-		Hostname:    "testmachine",
-		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
-		UserID:      1,
-		User: types.User{
-			Name: "user1",
-		},
-		RegisterMethod: util.RegisterMethodAuthKey,
-		HostInfo:       types.HostInfo(hostInfo),
-	}
-
-	pol := &policy.ACLPolicy{
-		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}},
-		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
-		ACLs: []policy.ACL{
-			{
-				Action:       "accept",
-				Sources:      []string{"*"},
-				Destinations: []string{"tag:test:*"},
-			},
-		},
-	}
-
-	// rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false)
-	// c.Assert(err, check.IsNil)
-	//
-	// c.Assert(rules, check.HasLen, 1)
-	// c.Assert(rules[0].DstPorts, check.HasLen, 1)
-	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
-
-	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
-	assert.NoError(t, err)
-
-	want := []tailcfg.FilterRule{
-		{
-			SrcIPs: []string{"0.0.0.0/0", "::/0"},
-			DstPorts: []tailcfg.NetPortRange{
-				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
-			},
-		},
-	}
-
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf(
-			"TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s",
-			diff,
-		)
-	}
-}
-
-// tag on a host is owned by a tag owner, the tag is valid.
-// an ACL rule is matching the tag to a user. It should not be valid since the
-// host should be tied to the tag now.
-func TestValidTagInvalidUser(t *testing.T) {
-	hostInfo := tailcfg.Hostinfo{
-		OS:          "centos",
-		Hostname:    "webserver",
-		RequestTags: []string{"tag:webapp"},
-	}
-
-	machine := types.Machine{
-		ID:          1,
-		MachineKey:  "12345",
-		NodeKey:     "bar",
-		DiscoKey:    "faa",
-		Hostname:    "webserver",
-		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
-		UserID:      1,
-		User: types.User{
-			Name: "user1",
-		},
-		RegisterMethod: util.RegisterMethodAuthKey,
-		HostInfo:       types.HostInfo(hostInfo),
-	}
-
-	hostInfo2 := tailcfg.Hostinfo{
-		OS:       "debian",
-		Hostname: "Hostname",
-	}
-
-	machine2 := types.Machine{
-		ID:          2,
-		MachineKey:  "56789",
-		NodeKey:     "bar2",
-		DiscoKey:    "faab",
-		Hostname:    "user",
-		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
-		UserID:      1,
-		User: types.User{
-			Name: "user1",
-		},
-		RegisterMethod: util.RegisterMethodAuthKey,
-		HostInfo:       types.HostInfo(hostInfo2),
-	}
-
-	pol := &policy.ACLPolicy{
-		TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
-		ACLs: []policy.ACL{
-			{
-				Action:       "accept",
-				Sources:      []string{"user1"},
-				Destinations: []string{"tag:webapp:80,443"},
-			},
-		},
-	}
-
-	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2})
-	assert.NoError(t, err)
-
-	want := []tailcfg.FilterRule{
-		{
-			SrcIPs: []string{"100.64.0.2/32"},
-			DstPorts: []tailcfg.NetPortRange{
-				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
-				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
-			},
-		},
-	}
-
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
-	}
-}
diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go
index f6d173c..2786a0d 100644
--- a/hscontrol/db/machine_test.go
+++ b/hscontrol/db/machine_test.go
@@ -292,10 +292,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
 	testPeers, err := db.ListPeers(testMachine)
 	c.Assert(err, check.IsNil)
 
-	adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers)
+	adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminMachine, adminPeers)
 	c.Assert(err, check.IsNil)
 
-	testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers)
+	testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testMachine, testPeers)
 	c.Assert(err, check.IsNil)
 
 	peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go
index 658ae0f..5d5509b 100644
--- a/hscontrol/mapper/mapper.go
+++ b/hscontrol/mapper/mapper.go
@@ -95,7 +95,7 @@ func fullMapResponse(
 		return nil, err
 	}
 
-	rules, sshPolicy, err := policy.GenerateFilterRules(
+	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
 		pol,
 		machine,
 		peers,
diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go
index dcf1ae3..bcdbb5d 100644
--- a/hscontrol/policy/acls.go
+++ b/hscontrol/policy/acls.go
@@ -117,7 +117,7 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
 // TODO(kradalby): This needs to be replace with something that generates
 // the rules as needed and not stores it on the global object, rules are
 // per node and that should be taken into account.
-func GenerateFilterRules(
+func GenerateFilterAndSSHRules(
 	policy *ACLPolicy,
 	machine *types.Machine,
 	peers types.Machines,
diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go
index 0808345..3995935 100644
--- a/hscontrol/policy/acls_test.go
+++ b/hscontrol/policy/acls_test.go
@@ -562,7 +562,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
 			},
 		},
 	}
-	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
+	_, _, err := GenerateFilterAndSSHRules(pol, &types.Machine{}, types.Machines{})
 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
 }
 
@@ -581,7 +581,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
 			},
 		},
 	}
-	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
+	_, _, err := GenerateFilterAndSSHRules(pol, &types.Machine{}, types.Machines{})
 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
 }
 
@@ -597,7 +597,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
 		},
 	}
 
-	_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
+	_, _, err := GenerateFilterAndSSHRules(pol, &types.Machine{}, types.Machines{})
 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
 }
 
@@ -2969,3 +2969,250 @@ func TestParseDestination(t *testing.T) {
 		})
 	}
 }
+
+// this test should validate that we can expand a group in a TagOWner section and
+// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
+// the tag is matched in the Sources section.
+func TestValidExpandTagOwnersInSources(t *testing.T) {
+	hostInfo := tailcfg.Hostinfo{
+		OS:          "centos",
+		Hostname:    "testmachine",
+		RequestTags: []string{"tag:test"},
+	}
+
+	machine := types.Machine{
+		ID:          0,
+		MachineKey:  "foo",
+		NodeKey:     "bar",
+		DiscoKey:    "faa",
+		Hostname:    "testmachine",
+		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
+		UserID:      0,
+		User: types.User{
+			Name: "user1",
+		},
+		RegisterMethod: util.RegisterMethodAuthKey,
+		HostInfo:       types.HostInfo(hostInfo),
+	}
+
+	pol := &ACLPolicy{
+		Groups:    Groups{"group:test": []string{"user1", "user2"}},
+		TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
+		ACLs: []ACL{
+			{
+				Action:       "accept",
+				Sources:      []string{"tag:test"},
+				Destinations: []string{"*:*"},
+			},
+		},
+	}
+
+	got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{})
+	assert.NoError(t, err)
+
+	want := []tailcfg.FilterRule{
+		{
+			SrcIPs: []string{"100.64.0.1/32"},
+			DstPorts: []tailcfg.NetPortRange{
+				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
+				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
+			},
+		},
+	}
+
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
+	}
+}
+
+// need a test with:
+// tag on a host that isn't owned by a tag owners. So the user
+// of the host should be valid.
+func TestInvalidTagValidUser(t *testing.T) {
+	hostInfo := tailcfg.Hostinfo{
+		OS:          "centos",
+		Hostname:    "testmachine",
+		RequestTags: []string{"tag:foo"},
+	}
+
+	machine := types.Machine{
+		ID:          1,
+		MachineKey:  "12345",
+		NodeKey:     "bar",
+		DiscoKey:    "faa",
+		Hostname:    "testmachine",
+		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
+		UserID:      1,
+		User: types.User{
+			Name: "user1",
+		},
+		RegisterMethod: util.RegisterMethodAuthKey,
+		HostInfo:       types.HostInfo(hostInfo),
+	}
+
+	pol := &ACLPolicy{
+		TagOwners: TagOwners{"tag:test": []string{"user1"}},
+		ACLs: []ACL{
+			{
+				Action:       "accept",
+				Sources:      []string{"user1"},
+				Destinations: []string{"*:*"},
+			},
+		},
+	}
+
+	got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{})
+	assert.NoError(t, err)
+
+	want := []tailcfg.FilterRule{
+		{
+			SrcIPs: []string{"100.64.0.1/32"},
+			DstPorts: []tailcfg.NetPortRange{
+				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
+				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
+			},
+		},
+	}
+
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
+	}
+}
+
+// this test should validate that we can expand a group in a TagOWner section and
+// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
+// the tag is matched in the Destinations section.
+func TestValidExpandTagOwnersInDestinations(t *testing.T) {
+	hostInfo := tailcfg.Hostinfo{
+		OS:          "centos",
+		Hostname:    "testmachine",
+		RequestTags: []string{"tag:test"},
+	}
+
+	machine := types.Machine{
+		ID:          1,
+		MachineKey:  "12345",
+		NodeKey:     "bar",
+		DiscoKey:    "faa",
+		Hostname:    "testmachine",
+		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
+		UserID:      1,
+		User: types.User{
+			Name: "user1",
+		},
+		RegisterMethod: util.RegisterMethodAuthKey,
+		HostInfo:       types.HostInfo(hostInfo),
+	}
+
+	pol := &ACLPolicy{
+		Groups:    Groups{"group:test": []string{"user1", "user2"}},
+		TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
+		ACLs: []ACL{
+			{
+				Action:       "accept",
+				Sources:      []string{"*"},
+				Destinations: []string{"tag:test:*"},
+			},
+		},
+	}
+
+	// rules, _, err := GenerateFilterRules(pol, &machine, peers, false)
+	// c.Assert(err, check.IsNil)
+	//
+	// c.Assert(rules, check.HasLen, 1)
+	// c.Assert(rules[0].DstPorts, check.HasLen, 1)
+	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
+
+	got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{})
+	assert.NoError(t, err)
+
+	want := []tailcfg.FilterRule{
+		{
+			SrcIPs: []string{"0.0.0.0/0", "::/0"},
+			DstPorts: []tailcfg.NetPortRange{
+				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
+			},
+		},
+	}
+
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf(
+			"TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s",
+			diff,
+		)
+	}
+}
+
+// tag on a host is owned by a tag owner, the tag is valid.
+// an ACL rule is matching the tag to a user. It should not be valid since the
+// host should be tied to the tag now.
+func TestValidTagInvalidUser(t *testing.T) {
+	hostInfo := tailcfg.Hostinfo{
+		OS:          "centos",
+		Hostname:    "webserver",
+		RequestTags: []string{"tag:webapp"},
+	}
+
+	machine := types.Machine{
+		ID:          1,
+		MachineKey:  "12345",
+		NodeKey:     "bar",
+		DiscoKey:    "faa",
+		Hostname:    "webserver",
+		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
+		UserID:      1,
+		User: types.User{
+			Name: "user1",
+		},
+		RegisterMethod: util.RegisterMethodAuthKey,
+		HostInfo:       types.HostInfo(hostInfo),
+	}
+
+	hostInfo2 := tailcfg.Hostinfo{
+		OS:       "debian",
+		Hostname: "Hostname",
+	}
+
+	machine2 := types.Machine{
+		ID:          2,
+		MachineKey:  "56789",
+		NodeKey:     "bar2",
+		DiscoKey:    "faab",
+		Hostname:    "user",
+		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
+		UserID:      1,
+		User: types.User{
+			Name: "user1",
+		},
+		RegisterMethod: util.RegisterMethodAuthKey,
+		HostInfo:       types.HostInfo(hostInfo2),
+	}
+
+	pol := &ACLPolicy{
+		TagOwners: TagOwners{"tag:webapp": []string{"user1"}},
+		ACLs: []ACL{
+			{
+				Action:       "accept",
+				Sources:      []string{"user1"},
+				Destinations: []string{"tag:webapp:80,443"},
+			},
+		},
+	}
+
+	got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{machine2})
+	assert.NoError(t, err)
+
+	want := []tailcfg.FilterRule{
+		{
+			SrcIPs: []string{"100.64.0.2/32"},
+			DstPorts: []tailcfg.NetPortRange{
+				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
+				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
+			},
+		},
+	}
+
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
+	}
+}