Add helper function to create a unique givenname

This commit is contained in:
Kristoffer Dalby 2022-05-16 20:30:43 +02:00
parent f4873d9387
commit 177c21b294
3 changed files with 177 additions and 0 deletions

View file

@ -26,6 +26,7 @@ const (
) )
errCouldNotConvertMachineInterface = Error("failed to convert machine interface") errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long") errHostnameTooLong = Error("Hostname too long")
MachineGivenNameHashLength = 8
) )
const ( const (
@ -813,3 +814,32 @@ func (machine *Machine) RoutesToProto() *v1.Routes {
EnabledRoutes: ipPrefixToString(enabledRoutes), EnabledRoutes: ipPrefixToString(enabledRoutes),
} }
} }
func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) {
// If a hostname is or will be longer than 63 chars after adding the hash,
// it needs to be trimmed.
trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2
normalizedHostname, err := NormalizeToFQDNRules(
suppliedName,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
return "", err
}
postfix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
if err != nil {
return "", err
}
// Verify that that the new unique name is shorter than the maximum allowed
// DNS segment.
if len(normalizedHostname) <= trimmedHostnameLength {
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname, postfix)
} else {
normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname[:trimmedHostnameLength], postfix)
}
return normalizedHostname, nil
}

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -654,3 +655,136 @@ func Test_getFilteredByACLPeers(t *testing.T) {
}) })
} }
} }
func TestHeadscale_GenerateGivenName(t *testing.T) {
type args struct {
suppliedName string
}
tests := []struct {
name string
h *Headscale
args args
want string
wantErr bool
}{
{
name: "simple machine name generation",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmachine",
},
want: "testmachine",
wantErr: false,
},
{
name: "machine name with 53 chars",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
},
want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
wantErr: false,
},
{
name: "machine name with 60 chars",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567",
},
want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
wantErr: false,
},
{
name: "machine name with 63 chars",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567890",
},
want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
wantErr: false,
},
{
name: "machine name with 64 chars",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567891",
},
want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
wantErr: false,
},
{
name: "machine name with 73 chars",
h: &Headscale{
cfg: Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
},
},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.h.GenerateGivenName(tt.args.suppliedName)
if (err != nil) != tt.wantErr {
t.Errorf(
"Headscale.GenerateGivenName() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if tt.want != "" && strings.Contains(tt.want, got) {
t.Errorf(
"Headscale.GenerateGivenName() = %v, is not a substring of %v",
tt.want,
got,
)
}
if len(got) > labelHostnameLength {
t.Errorf(
"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d",
got,
labelHostnameLength,
)
}
})
}
}

View file

@ -317,3 +317,16 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
return base64.RawURLEncoding.EncodeToString(b), err return base64.RawURLEncoding.EncodeToString(b), err
} }
// GenerateRandomStringDNSSafe returns a DNS-safe
// securely generated random string.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func GenerateRandomStringDNSSafe(n int) (string, error) {
str, err := GenerateRandomStringURLSafe(n)
str = strings.ReplaceAll(str, "_", "-")
return str[:n], err
}