From 2675ff4b94011d7ba248391beee2965173a841dd Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 12 Jun 2023 15:59:05 +0200 Subject: [PATCH] make parse destination string into a func Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls.go | 63 +++++++++++++++++++---------------- hscontrol/policy/acls_test.go | 63 +++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 29 deletions(-) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index d667c72..d9ec3c7 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -375,9 +375,39 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( machines types.Machines, needsWildcard bool, ) ([]tailcfg.NetPortRange, error) { - var tokens []string + alias, port, err := parseDestination(dest) + if err != nil { + return nil, err + } - log.Trace().Str("destination", dest).Msg("generating policy destination") + expanded, err := pol.ExpandAlias( + machines, + alias, + ) + if err != nil { + return nil, err + } + ports, err := expandPorts(port, needsWildcard) + if err != nil { + return nil, err + } + + dests := []tailcfg.NetPortRange{} + for _, dest := range expanded.Prefixes() { + for _, port := range *ports { + pr := tailcfg.NetPortRange{ + IP: dest.String(), + Ports: port, + } + dests = append(dests, pr) + } + } + + return dests, nil +} + +func parseDestination(dest string) (string, string, error) { + var tokens []string // Check if there is a IPv4/6:Port combination, IPv6 has more than // three ":". @@ -397,7 +427,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() { log.Trace().Err(err).Msg("trying to parse as IPv6") - return nil, fmt.Errorf( + return "", "", fmt.Errorf( "failed to parse destination, tokens %v: %w", tokens, ErrInvalidPortFormat, @@ -407,8 +437,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( } } - log.Trace().Strs("tokens", tokens).Msg("generating policy destination") - var alias string // We can have here stuff like: // git-server:* @@ -424,30 +452,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) } - expanded, err := pol.ExpandAlias( - machines, - alias, - ) - if err != nil { - return nil, err - } - ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard) - if err != nil { - return nil, err - } - - dests := []tailcfg.NetPortRange{} - for _, dest := range expanded.Prefixes() { - for _, port := range *ports { - pr := tailcfg.NetPortRange{ - IP: dest.String(), - Ports: port, - } - dests = append(dests, pr) - } - } - - return dests, nil + return alias, tokens[len(tokens)-1], nil } // parseProtocol reads the proto field of the ACL and generates a list of diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index e220297..94fdc66 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2557,3 +2557,66 @@ func TestSSHRules(t *testing.T) { }) } } + +func TestParseDestination(t *testing.T) { + tests := []struct { + dest string + wantAlias string + wantPort string + }{ + { + dest: "git-server:*", + wantAlias: "git-server", + wantPort: "*", + }, + { + dest: "192.168.1.0/24:22", + wantAlias: "192.168.1.0/24", + wantPort: "22", + }, + { + dest: "192.168.1.1:22", + wantAlias: "192.168.1.1", + wantPort: "22", + }, + { + dest: "fd7a:115c:a1e0::2:22", + wantAlias: "fd7a:115c:a1e0::2", + wantPort: "22", + }, + { + dest: "fd7a:115c:a1e0::2/128:22", + wantAlias: "fd7a:115c:a1e0::2/128", + wantPort: "22", + }, + { + dest: "tag:montreal-webserver:80,443", + wantAlias: "tag:montreal-webserver", + wantPort: "80,443", + }, + { + dest: "tag:api-server:443", + wantAlias: "tag:api-server", + wantPort: "443", + }, + { + dest: "example-host-1:*", + wantAlias: "example-host-1", + wantPort: "*", + }, + } + + for _, tt := range tests { + t.Run(tt.dest, func(t *testing.T) { + alias, port, _ := parseDestination(tt.dest) + + if alias != tt.wantAlias { + t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias) + } + + if port != tt.wantPort { + t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port) + } + }) + } +}