make generateFilterRules take machine and peers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-12 11:18:58 +02:00 committed by Kristoffer Dalby
parent 9c425a1c08
commit 161243c787
2 changed files with 24 additions and 19 deletions

View file

@ -128,7 +128,7 @@ func GenerateFilterRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain) rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -152,10 +152,12 @@ func GenerateFilterRules(
// generateFilterRules takes a set of machines and an ACLPolicy and generates a // generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules( func (pol *ACLPolicy) generateFilterRules(
machines types.Machines, machine *types.Machine,
peers types.Machines,
stripEmailDomain bool, stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
machines := append(peers, *machine)
for index, acl := range pol.ACLs { for index, acl := range pol.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {

View file

@ -199,7 +199,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -230,7 +230,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
pol, err := LoadACLPolicyFromBytes(acl, "hujson") pol, err := LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
} }
@ -310,7 +310,7 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -366,7 +366,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -401,7 +401,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -428,7 +428,7 @@ acls:
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -459,7 +459,7 @@ acls:
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -1620,7 +1620,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
pol ACLPolicy pol ACLPolicy
} }
type args struct { type args struct {
machines types.Machines machine types.Machine
peers types.Machines
stripEmailDomain bool stripEmailDomain bool
} }
tests := []struct { tests := []struct {
@ -1651,7 +1652,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
}, },
args: args{ args: args{
machines: types.Machines{}, machine: types.Machine{},
peers: types.Machines{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1691,14 +1693,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
}, },
args: args{ args: args{
machines: types.Machines{ machine: types.Machine{
{
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
}, },
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.2"),
@ -1739,7 +1741,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.generateFilterRules( got, err := tt.field.pol.generateFilterRules(
tt.args.machines, &tt.args.machine,
tt.args.peers,
tt.args.stripEmailDomain, tt.args.stripEmailDomain,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {