make GenerateFilterRules take machine and peers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-08 19:10:09 +02:00 committed by Kristoffer Dalby
parent 35770278f7
commit db6cf4ac0a
6 changed files with 291 additions and 316 deletions

View file

@ -2,10 +2,13 @@ package db
import ( import (
"net/netip" "net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -77,7 +80,7 @@ func (s *Suite) TestSshRules(c *check.C) {
}, },
} }
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) _, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, &machine, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sshPolicy, check.NotNil) c.Assert(sshPolicy, check.NotNil)
@ -94,15 +97,7 @@ func (s *Suite) TestSshRules(c *check.C) {
// this test should validate that we can expand a group in a TagOWner section and // this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Sources section. // the tag is matched in the Sources section.
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { func TestValidExpandTagOwnersInSources(t *testing.T) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{ hostInfo := tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "testmachine", Hostname: "testmachine",
@ -116,13 +111,13 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID, UserID: 0,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo), HostInfo: types.HostInfo(hostInfo),
} }
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{ pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
@ -136,85 +131,28 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
}, },
} }
machines, err := db.ListMachines() got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
c.Assert(err, check.IsNil) assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false) want := []tailcfg.FilterRule{
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
{ {
Action: "accept", SrcIPs: []string{"100.64.0.1/32"},
Sources: []string{"*"}, DstPorts: []tailcfg.NetPortRange{
Destinations: []string{"tag:test:*"}, {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
}, },
}, },
} }
machines, err := db.ListMachines() if diff := cmp.Diff(want, got); diff != "" {
c.Assert(err, check.IsNil) t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
}
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
} }
// need a test with: // need a test with:
// tag on a host that isn't owned by a tag owners. So the user // tag on a host that isn't owned by a tag owners. So the user
// of the host should be valid. // of the host should be valid.
func (s *Suite) TestInvalidTagValidUser(c *check.C) { func TestInvalidTagValidUser(t *testing.T) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{ hostInfo := tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "testmachine", Hostname: "testmachine",
@ -228,13 +166,13 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID, UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo), HostInfo: types.HostInfo(hostInfo),
} }
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{ pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
@ -247,190 +185,38 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
}, },
} }
machines, err := db.ListMachines() got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
c.Assert(err, check.IsNil) assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false) want := []tailcfg.FilterRule{
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "webserver")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "user")
hostInfo2 := tailcfg.Hostinfo{
OS: "debian",
Hostname: "Hostname",
}
c.Assert(err, check.NotNil)
machine = types.Machine{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
ACLs: []policy.ACL{
{ {
Action: "accept", SrcIPs: []string{"100.64.0.1/32"},
Sources: []string{"user1"}, DstPorts: []tailcfg.NetPortRange{
Destinations: []string{"tag:webapp:80,443"}, {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
}, },
}, },
} }
machines, err := db.ListMachines() if diff := cmp.Diff(want, got); diff != "" {
c.Assert(err, check.IsNil) t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
}
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
c.Assert(rules[0].DstPorts, check.HasLen, 2)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
} }
func (s *Suite) TestPortUser(c *check.C) { func TestPortGroup(t *testing.T) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
machine := types.Machine{
ID: 0,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips,
AuthKeyID: uint(pak.ID),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(`
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
}
func (s *Suite) TestPortGroup(c *check.C) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
machine := types.Machine{ machine := types.Machine{
ID: 0, ID: 0,
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Hostname: "testmachine", Hostname: "testmachine",
UserID: user.ID, UserID: 0,
User: types.User{
Name: "testuser",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.5")},
AuthKeyID: uint(pak.ID),
} }
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(` acl := []byte(`
{ {
@ -459,22 +245,211 @@ func (s *Suite) TestPortGroup(c *check.C) {
} }
`) `)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil) assert.NoError(t, err)
machines, err := db.ListMachines() got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
c.Assert(err, check.IsNil) assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false) want := []tailcfg.FilterRule{
c.Assert(err, check.IsNil) {
SrcIPs: []string{"100.64.0.5/32"},
c.Assert(rules, check.NotNil) DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
c.Assert(rules, check.HasLen, 1) },
c.Assert(rules[0].DstPorts, check.HasLen, 1) },
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) }
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1) if diff := cmp.Diff(want, got); diff != "" {
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") t.Errorf("TestPortGroup() unexpected result (-want +got):\n%s", diff)
c.Assert(len(ips), check.Equals, 1) }
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") }
func TestPortUser(t *testing.T) {
machine := types.Machine{
ID: 0,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: 0,
User: types.User{
Name: "testuser",
},
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.9")},
}
acl := []byte(`
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.9/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestPortUser() unexpected result (-want +got):\n%s", diff)
}
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func TestValidExpandTagOwnersInDestinations(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
}
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
},
},
}
// rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false)
// c.Assert(err, check.IsNil)
//
// c.Assert(rules, check.HasLen, 1)
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", diff)
}
}
// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func TestValidTagInvalidUser(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
}
hostInfo2 := tailcfg.Hostinfo{
OS: "debian",
Hostname: "Hostname",
}
machine2 := types.Machine{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo2),
}
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"tag:webapp:80,443"},
},
},
}
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
}
} }

View file

@ -287,14 +287,20 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
machines, err := db.ListMachines() adminPeers, err := db.ListPeers(adminMachine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) testPeers, err := db.ListPeers(testMachine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules) adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false)
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false)
c.Assert(err, check.IsNil)
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, testPeers, testRules)
c.Log(peersOfTestMachine) c.Log(peersOfTestMachine)
c.Assert(len(peersOfTestMachine), check.Equals, 9) c.Assert(len(peersOfTestMachine), check.Equals, 9)

View file

@ -101,8 +101,8 @@ func fullMapResponse(
rules, sshPolicy, err := policy.GenerateFilterRules( rules, sshPolicy, err := policy.GenerateFilterRules(
pol, pol,
// The policy is currently calculated for the entire Headscale network machine,
append(peers, *machine), peers,
stripEmailDomain, stripEmailDomain,
) )
if err != nil { if err != nil {

View file

@ -360,7 +360,7 @@ func Test_fullMapResponse(t *testing.T) {
CollectServices: "false", CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{}, PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{}, ControlTime: &time.Time{},
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableLogTail: true, DisableLogTail: true,
@ -393,7 +393,7 @@ func Test_fullMapResponse(t *testing.T) {
CollectServices: "false", CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{}, PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{}, ControlTime: &time.Time{},
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableLogTail: true, DisableLogTail: true,
@ -442,7 +442,7 @@ func Test_fullMapResponse(t *testing.T) {
}, },
}, },
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{}, ControlTime: &time.Time{},
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableLogTail: true, DisableLogTail: true,

View file

@ -18,7 +18,6 @@ import (
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -54,8 +53,6 @@ const (
ProtocolFC = 133 // Fibre Channel ProtocolFC = 133 // Fibre Channel
) )
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
log.Debug(). log.Debug().
@ -122,7 +119,8 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
// per node and that should be taken into account. // per node and that should be taken into account.
func GenerateFilterRules( func GenerateFilterRules(
policy *ACLPolicy, policy *ACLPolicy,
machines types.Machines, machine *types.Machine,
peers types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all // If there is no policy defined, we default to allow all
@ -130,7 +128,7 @@ func GenerateFilterRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.generateFilterRules(machines, stripEmailDomain) rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -138,8 +136,7 @@ func GenerateFilterRules(
log.Trace().Interface("ACL", rules).Msg("ACL rules generated") log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
var sshPolicy *tailcfg.SSHPolicy var sshPolicy *tailcfg.SSHPolicy
if featureEnableSSH() { sshRules, err := generateSSHRules(policy, append(peers, *machine), stripEmailDomain)
sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -148,9 +145,6 @@ func GenerateFilterRules(
sshPolicy = &tailcfg.SSHPolicy{} sshPolicy = &tailcfg.SSHPolicy{}
} }
sshPolicy.Rules = sshRules sshPolicy.Rules = sshRules
} else if policy != nil && len(policy.SSHs) > 0 {
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
}
return rules, sshPolicy, nil return rules, sshPolicy, nil
} }

View file

@ -245,7 +245,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterRules(pol, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -264,7 +264,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterRules(pol, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -280,7 +280,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterRules(pol, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }