create DB struct
This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b01f1f1867
commit
14e29a7bee
48 changed files with 1731 additions and 1572 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -41,3 +41,5 @@ integration_test/etc/config.dump.yaml
|
||||||
# MkDocs
|
# MkDocs
|
||||||
.cache
|
.cache
|
||||||
/site
|
/site
|
||||||
|
|
||||||
|
__debug_bin
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/prometheus/common/model"
|
"github.com/prometheus/common/model"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -83,7 +83,7 @@ var listAPIKeys = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
tableData = append(tableData, []string{
|
tableData = append(tableData, []string{
|
||||||
strconv.FormatUint(key.GetId(), hscontrol.Base10),
|
strconv.FormatUint(key.GetId(), util.Base10),
|
||||||
key.GetPrefix(),
|
key.GetPrefix(),
|
||||||
expiration,
|
expiration,
|
||||||
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -93,7 +93,7 @@ var createNodeCmd = &cobra.Command{
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !hscontrol.NodePublicKeyRegex.Match([]byte(machineKey)) {
|
if !util.NodePublicKeyRegex.Match([]byte(machineKey)) {
|
||||||
err = errPreAuthKeyMalformed
|
err = errPreAuthKeyMalformed
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
survey "github.com/AlecAivazis/survey/v2"
|
survey "github.com/AlecAivazis/survey/v2"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -529,7 +529,7 @@ func nodesToPtables(
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText(
|
err := machineKey.UnmarshalText(
|
||||||
[]byte(hscontrol.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
machineKey = key.MachinePublic{}
|
machineKey = key.MachinePublic{}
|
||||||
|
@ -537,7 +537,7 @@ func nodesToPtables(
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err = nodeKey.UnmarshalText(
|
err = nodeKey.UnmarshalText(
|
||||||
[]byte(hscontrol.NodePublicKeyEnsurePrefix(machine.NodeKey)),
|
[]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -596,7 +596,7 @@ func nodesToPtables(
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeData := []string{
|
nodeData := []string{
|
||||||
strconv.FormatUint(machine.Id, hscontrol.Base10),
|
strconv.FormatUint(machine.Id, util.Base10),
|
||||||
machine.Name,
|
machine.Name,
|
||||||
machine.GetGivenName(),
|
machine.GetGivenName(),
|
||||||
machineKey.ShortString(),
|
machineKey.ShortString(),
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
survey "github.com/AlecAivazis/survey/v2"
|
survey "github.com/AlecAivazis/survey/v2"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -20,9 +20,7 @@ func init() {
|
||||||
userCmd.AddCommand(renameUserCmd)
|
userCmd.AddCommand(renameUserCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
var errMissingParameter = errors.New("missing parameters")
|
||||||
errMissingParameter = hscontrol.Error("missing parameters")
|
|
||||||
)
|
|
||||||
|
|
||||||
var userCmd = &cobra.Command{
|
var userCmd = &cobra.Command{
|
||||||
Use: "users",
|
Use: "users",
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
@ -39,7 +40,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
|
||||||
// We are doing this here, as in the future could be cool to have it also hot-reload
|
// We are doing this here, as in the future could be cool to have it also hot-reload
|
||||||
|
|
||||||
if cfg.ACL.PolicyPath != "" {
|
if cfg.ACL.PolicyPath != "" {
|
||||||
aclPath := hscontrol.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
|
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
|
||||||
err = app.LoadACLPolicyFromPath(aclPath)
|
err = app.LoadACLPolicyFromPath(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
|
@ -98,7 +99,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
|
||||||
grpcOptions = append(
|
grpcOptions = append(
|
||||||
grpcOptions,
|
grpcOptions,
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithContextDialer(hscontrol.GrpcSocketDialer),
|
grpc.WithContextDialer(util.GrpcSocketDialer),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// If we are not connecting to a local server, require an API key for authentication
|
// If we are not connecting to a local server, require an API key for authentication
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
@ -64,7 +65,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||||
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
||||||
c.Assert(
|
c.Assert(
|
||||||
hscontrol.GetFileMode("unix_socket_permission"),
|
util.GetFileMode("unix_socket_permission"),
|
||||||
check.Equals,
|
check.Equals,
|
||||||
fs.FileMode(0o770),
|
fs.FileMode(0o770),
|
||||||
)
|
)
|
||||||
|
@ -107,7 +108,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||||
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
||||||
c.Assert(
|
c.Assert(
|
||||||
hscontrol.GetFileMode("unix_socket_permission"),
|
util.GetFileMode("unix_socket_permission"),
|
||||||
check.Equals,
|
check.Equals,
|
||||||
fs.FileMode(0o770),
|
fs.FileMode(0o770),
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tailscale/hujson"
|
"github.com/tailscale/hujson"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
|
@ -20,21 +21,16 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
errEmptyPolicy = Error("empty policy")
|
errEmptyPolicy = errors.New("empty policy")
|
||||||
errInvalidAction = Error("invalid action")
|
errInvalidAction = errors.New("invalid action")
|
||||||
errInvalidGroup = Error("invalid group")
|
errInvalidGroup = errors.New("invalid group")
|
||||||
errInvalidTag = Error("invalid tag")
|
errInvalidTag = errors.New("invalid tag")
|
||||||
errInvalidPortFormat = Error("invalid port format")
|
errInvalidPortFormat = errors.New("invalid port format")
|
||||||
errWildcardIsNeeded = Error("wildcard as port is required for the protocol")
|
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Base8 = 8
|
|
||||||
Base10 = 10
|
|
||||||
BitSize16 = 16
|
|
||||||
BitSize32 = 32
|
|
||||||
BitSize64 = 64
|
|
||||||
portRangeBegin = 0
|
portRangeBegin = 0
|
||||||
portRangeEnd = 65535
|
portRangeEnd = 65535
|
||||||
expectedTokenItems = 2
|
expectedTokenItems = 2
|
||||||
|
@ -123,7 +119,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) UpdateACLRules() error {
|
func (h *Headscale) UpdateACLRules() error {
|
||||||
machines, err := h.ListMachines()
|
machines, err := h.db.ListMachines()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -230,7 +226,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||||
return nil, errEmptyPolicy
|
return nil, errEmptyPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
machines, err := h.ListMachines()
|
machines, err := h.db.ListMachines()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -570,7 +566,7 @@ func excludeCorrectlyTaggedNodes(
|
||||||
for tag := range aclPolicy.TagOwners {
|
for tag := range aclPolicy.TagOwners {
|
||||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
||||||
ns := append(owners, user)
|
ns := append(owners, user)
|
||||||
if contains(ns, user) {
|
if util.StringOrPrefixListContains(ns, user) {
|
||||||
tags = append(tags, tag)
|
tags = append(tags, tag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -580,7 +576,7 @@ func excludeCorrectlyTaggedNodes(
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, t := range hi.RequestTags {
|
for _, t := range hi.RequestTags {
|
||||||
if contains(tags, t) {
|
if util.StringOrPrefixListContains(tags, t) {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
|
@ -614,7 +610,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||||
rang := strings.Split(portStr, "-")
|
rang := strings.Split(portStr, "-")
|
||||||
switch len(rang) {
|
switch len(rang) {
|
||||||
case 1:
|
case 1:
|
||||||
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -624,11 +620,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||||
})
|
})
|
||||||
|
|
||||||
case expectedTokenItems:
|
case expectedTokenItems:
|
||||||
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
|
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -754,7 +750,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
|
|
||||||
// check for forced tags
|
// check for forced tags
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
if contains(machine.ForcedTags, alias) {
|
if util.StringOrPrefixListContains(machine.ForcedTags, alias) {
|
||||||
machine.IPAddresses.AppendToIPSet(&build)
|
machine.IPAddresses.AppendToIPSet(&build)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -783,7 +779,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
machines := filterMachinesByUser(machines, user)
|
machines := filterMachinesByUser(machines, user)
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
hi := machine.GetHostInfo()
|
hi := machine.GetHostInfo()
|
||||||
if contains(hi.RequestTags, alias) {
|
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
|
||||||
machine.IPAddresses.AppendToIPSet(&build)
|
machine.IPAddresses.AppendToIPSet(&build)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -238,13 +238,13 @@ func (s *Suite) TestInvalidAction(c *check.C) {
|
||||||
func (s *Suite) TestSshRules(c *check.C) {
|
func (s *Suite) TestSshRules(c *check.C) {
|
||||||
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
|
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
|
||||||
|
|
||||||
user, err := app.CreateUser("user1")
|
user, err := app.db.CreateUser("user1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user1", "testmachine")
|
_, err = app.db.GetMachine("user1", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
hostInfo := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
OS: "centos",
|
OS: "centos",
|
||||||
|
@ -264,7 +264,7 @@ func (s *Suite) TestSshRules(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
Groups: Groups{
|
Groups: Groups{
|
||||||
|
@ -348,13 +348,13 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
|
||||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||||
// the tag is matched in the Sources section.
|
// the tag is matched in the Sources section.
|
||||||
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||||
user, err := app.CreateUser("user1")
|
user, err := app.db.CreateUser("user1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user1", "testmachine")
|
_, err = app.db.GetMachine("user1", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
hostInfo := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
OS: "centos",
|
OS: "centos",
|
||||||
|
@ -374,7 +374,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
Groups: Groups{"group:test": []string{"user1", "user2"}},
|
Groups: Groups{"group:test": []string{"user1", "user2"}},
|
||||||
|
@ -398,13 +398,13 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||||
// the tag is matched in the Destinations section.
|
// the tag is matched in the Destinations section.
|
||||||
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||||
user, err := app.CreateUser("user1")
|
user, err := app.db.CreateUser("user1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user1", "testmachine")
|
_, err = app.db.GetMachine("user1", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
hostInfo := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
OS: "centos",
|
OS: "centos",
|
||||||
|
@ -424,7 +424,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
Groups: Groups{"group:test": []string{"user1", "user2"}},
|
Groups: Groups{"group:test": []string{"user1", "user2"}},
|
||||||
|
@ -448,13 +448,13 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||||
// tag on a host that isn't owned by a tag owners. So the user
|
// tag on a host that isn't owned by a tag owners. So the user
|
||||||
// of the host should be valid.
|
// of the host should be valid.
|
||||||
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||||
user, err := app.CreateUser("user1")
|
user, err := app.db.CreateUser("user1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user1", "testmachine")
|
_, err = app.db.GetMachine("user1", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
hostInfo := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
OS: "centos",
|
OS: "centos",
|
||||||
|
@ -474,7 +474,7 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
TagOwners: TagOwners{"tag:test": []string{"user1"}},
|
TagOwners: TagOwners{"tag:test": []string{"user1"}},
|
||||||
|
@ -497,13 +497,13 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||||
// an ACL rule is matching the tag to a user. It should not be valid since the
|
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||||
// host should be tied to the tag now.
|
// host should be tied to the tag now.
|
||||||
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||||
user, err := app.CreateUser("user1")
|
user, err := app.db.CreateUser("user1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user1", "webserver")
|
_, err = app.db.GetMachine("user1", "webserver")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
hostInfo := tailcfg.Hostinfo{
|
hostInfo := tailcfg.Hostinfo{
|
||||||
OS: "centos",
|
OS: "centos",
|
||||||
|
@ -523,8 +523,8 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
_, err = app.GetMachine("user1", "user")
|
_, err = app.db.GetMachine("user1", "user")
|
||||||
hostInfo2 := tailcfg.Hostinfo{
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
OS: "debian",
|
OS: "debian",
|
||||||
Hostname: "Hostname",
|
Hostname: "Hostname",
|
||||||
|
@ -542,7 +542,7 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo2),
|
HostInfo: HostInfo(hostInfo2),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
TagOwners: TagOwners{"tag:webapp": []string{"user1"}},
|
TagOwners: TagOwners{"tag:webapp": []string{"user1"}},
|
||||||
|
@ -694,8 +694,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortWildcardYAML(c *check.C) {
|
func (s *Suite) TestPortWildcardYAML(c *check.C) {
|
||||||
acl := []byte(`
|
acl := []byte(`---
|
||||||
---
|
|
||||||
hosts:
|
hosts:
|
||||||
host-1: 100.100.100.100/32
|
host-1: 100.100.100.100/32
|
||||||
subnet-1: 100.100.101.100/24
|
subnet-1: 100.100.101.100/24
|
||||||
|
@ -704,8 +703,7 @@ acls:
|
||||||
src:
|
src:
|
||||||
- "*"
|
- "*"
|
||||||
dst:
|
dst:
|
||||||
- host-1:*
|
- host-1:*`)
|
||||||
`)
|
|
||||||
err := app.LoadACLPolicyFromBytes(acl, "yaml")
|
err := app.LoadACLPolicyFromBytes(acl, "yaml")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -722,15 +720,15 @@ acls:
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortUser(c *check.C) {
|
func (s *Suite) TestPortUser(c *check.C) {
|
||||||
user, err := app.CreateUser("testuser")
|
user, err := app.db.CreateUser("testuser")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("testuser", "testmachine")
|
_, err = app.db.GetMachine("testuser", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
ips, _ := app.getAvailableIPs()
|
ips, _ := app.db.getAvailableIPs()
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "12345",
|
MachineKey: "12345",
|
||||||
|
@ -742,7 +740,7 @@ func (s *Suite) TestPortUser(c *check.C) {
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
acl := []byte(`
|
acl := []byte(`
|
||||||
{
|
{
|
||||||
|
@ -767,7 +765,7 @@ func (s *Suite) TestPortUser(c *check.C) {
|
||||||
err = app.LoadACLPolicyFromBytes(acl, "hujson")
|
err = app.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machines, err := app.ListMachines()
|
machines, err := app.db.ListMachines()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := app.aclPolicy.generateFilterRules(machines, false)
|
rules, err := app.aclPolicy.generateFilterRules(machines, false)
|
||||||
|
@ -785,15 +783,15 @@ func (s *Suite) TestPortUser(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestPortGroup(c *check.C) {
|
func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
user, err := app.CreateUser("testuser")
|
user, err := app.db.CreateUser("testuser")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("testuser", "testmachine")
|
_, err = app.db.GetMachine("testuser", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
ips, _ := app.getAvailableIPs()
|
ips, _ := app.db.getAvailableIPs()
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
|
@ -805,7 +803,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
acl := []byte(`
|
acl := []byte(`
|
||||||
{
|
{
|
||||||
|
@ -836,7 +834,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
|
||||||
err = app.LoadACLPolicyFromBytes(acl, "hujson")
|
err = app.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machines, err := app.ListMachines()
|
machines, err := app.db.ListMachines()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := app.aclPolicy.generateFilterRules(machines, false)
|
rules, err := app.aclPolicy.generateFilterRules(machines, false)
|
||||||
|
|
98
hscontrol/addresses.go
Normal file
98
hscontrol/addresses.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
// Codehere is mostly taken from github.com/tailscale/tailscale
|
||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package hscontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"go4.org/netipx"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) {
|
||||||
|
var ips MachineAddresses
|
||||||
|
var err error
|
||||||
|
for _, ipPrefix := range hsdb.ipPrefixes {
|
||||||
|
var ip *netip.Addr
|
||||||
|
ip, err = hsdb.getAvailableIP(ipPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return ips, err
|
||||||
|
}
|
||||||
|
ips = append(ips, *ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||||
|
usedIps, err := hsdb.getUsedIPs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
|
||||||
|
|
||||||
|
// Get the first IP in our prefix
|
||||||
|
ip := ipPrefixNetworkAddress.Next()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if !ipPrefix.Contains(ip) {
|
||||||
|
return nil, ErrCouldNotAllocateIP
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||||
|
fallthrough
|
||||||
|
case usedIps.Contains(ip):
|
||||||
|
fallthrough
|
||||||
|
case ip == netip.Addr{} || ip.IsLoopback():
|
||||||
|
ip = ip.Next()
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
default:
|
||||||
|
return &ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
|
||||||
|
// FIXME: This really deserves a better data model,
|
||||||
|
// but this was quick to get running and it should be enough
|
||||||
|
// to begin experimenting with a dual stack tailnet.
|
||||||
|
var addressesSlices []string
|
||||||
|
hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||||
|
|
||||||
|
var ips netipx.IPSetBuilder
|
||||||
|
for _, slice := range addressesSlices {
|
||||||
|
var machineAddresses MachineAddresses
|
||||||
|
err := machineAddresses.Scan(slice)
|
||||||
|
if err != nil {
|
||||||
|
return &netipx.IPSet{}, fmt.Errorf(
|
||||||
|
"failed to read ip from database: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range machineAddresses {
|
||||||
|
ips.Add(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipSet, err := ips.IPSet()
|
||||||
|
if err != nil {
|
||||||
|
return &netipx.IPSet{}, fmt.Errorf(
|
||||||
|
"failed to build IP Set: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipSet, nil
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
ips, err := app.getAvailableIPs()
|
ips, err := app.db.getAvailableIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -19,16 +19,16 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
ips, err := app.getAvailableIPs()
|
ips, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
user, err := app.CreateUser("test-ip")
|
user, err := app.db.CreateUser("test-ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -42,9 +42,9 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
usedIps, err := app.getUsedIPs()
|
usedIps, err := app.db.getUsedIPs()
|
||||||
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||||
|
|
||||||
machine1, err := app.GetMachineByID(0)
|
machine1, err := app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||||
|
@ -64,19 +64,19 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
user, err := app.CreateUser("test-ip-multi")
|
user, err := app.db.CreateUser("test-ip-multi")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
for index := 1; index <= 350; index++ {
|
for index := 1; index <= 350; index++ {
|
||||||
app.ipAllocationMutex.Lock()
|
app.db.ipAllocationMutex.Lock()
|
||||||
|
|
||||||
ips, err := app.getAvailableIPs()
|
ips, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -90,12 +90,12 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
app.ipAllocationMutex.Unlock()
|
app.db.ipAllocationMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
usedIps, err := app.getUsedIPs()
|
usedIps, err := app.db.getUsedIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||||
|
@ -117,7 +117,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||||
|
|
||||||
// Check that we can read back the IPs
|
// Check that we can read back the IPs
|
||||||
machine1, err := app.GetMachineByID(1)
|
machine1, err := app.db.GetMachineByID(1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
|
@ -126,7 +126,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
netip.MustParseAddr("10.27.0.1"),
|
netip.MustParseAddr("10.27.0.1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
machine50, err := app.GetMachineByID(50)
|
machine50, err := app.db.GetMachineByID(50)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
|
@ -136,7 +136,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
)
|
)
|
||||||
|
|
||||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||||
nextIP, err := app.getAvailableIPs()
|
nextIP, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(nextIP), check.Equals, 1)
|
c.Assert(len(nextIP), check.Equals, 1)
|
||||||
|
@ -144,7 +144,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
|
|
||||||
// If we call get Available again, we should receive
|
// If we call get Available again, we should receive
|
||||||
// the same IP, as it has not been reserved.
|
// the same IP, as it has not been reserved.
|
||||||
nextIP2, err := app.getAvailableIPs()
|
nextIP2, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(nextIP2), check.Equals, 1)
|
c.Assert(len(nextIP2), check.Equals, 1)
|
||||||
|
@ -152,7 +152,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
ips, err := app.getAvailableIPs()
|
ips, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
expected := netip.MustParseAddr("10.27.0.1")
|
expected := netip.MustParseAddr("10.27.0.1")
|
||||||
|
@ -160,13 +160,13 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
c.Assert(len(ips), check.Equals, 1)
|
c.Assert(len(ips), check.Equals, 1)
|
||||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||||
|
|
||||||
user, err := app.CreateUser("test-ip")
|
user, err := app.db.CreateUser("test-ip")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -179,23 +179,11 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
ips2, err := app.getAvailableIPs()
|
ips2, err := app.db.getAvailableIPs()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(ips2), check.Equals, 1)
|
c.Assert(len(ips2), check.Equals, 1)
|
||||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGenerateRandomStringDNSSafe(c *check.C) {
|
|
||||||
for i := 0; i < 100000; i++ {
|
|
||||||
str, err := GenerateRandomStringDNSSafe(8)
|
|
||||||
if err != nil {
|
|
||||||
c.Error(err)
|
|
||||||
}
|
|
||||||
if len(str) != 8 {
|
|
||||||
c.Error("invalid length", len(str), str)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,11 +3,13 @@ package hscontrol
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
@ -19,9 +21,10 @@ const (
|
||||||
RegisterMethodAuthKey = "authkey"
|
RegisterMethodAuthKey = "authkey"
|
||||||
RegisterMethodOIDC = "oidc"
|
RegisterMethodOIDC = "oidc"
|
||||||
RegisterMethodCLI = "cli"
|
RegisterMethodCLI = "cli"
|
||||||
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
|
|
||||||
"machines registered with CLI does not support expire",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||||
|
"machines registered with CLI does not support expire",
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Headscale) HealthHandler(
|
func (h *Headscale) HealthHandler(
|
||||||
|
@ -53,7 +56,7 @@ func (h *Headscale) HealthHandler(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.pingDB(req.Context()); err != nil {
|
if err := h.db.pingDB(req.Context()); err != nil {
|
||||||
respond(err)
|
respond(err)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -95,7 +98,7 @@ func (h *Headscale) RegisterWebAPI(
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
nodeKeyStr, ok := vars["nkey"]
|
nodeKeyStr, ok := vars["nkey"]
|
||||||
|
|
||||||
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
@ -116,7 +119,7 @@ func (h *Headscale) RegisterWebAPI(
|
||||||
// the template and log an error.
|
// the template and log an error.
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||||
)
|
)
|
||||||
|
|
||||||
if !ok || nodeKeyStr == "" || err != nil {
|
if !ok || nodeKeyStr == "" || err != nil {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package hscontrol
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -15,7 +16,7 @@ func (h *Headscale) generateMapResponse(
|
||||||
Str("func", "generateMapResponse").
|
Str("func", "generateMapResponse").
|
||||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||||
Msg("Creating Map response")
|
Msg("Creating Map response")
|
||||||
node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -26,7 +27,7 @@ func (h *Headscale) generateMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.getValidPeers(machine)
|
peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -37,9 +38,9 @@ func (h *Headscale) generateMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := h.getMapResponseUserProfiles(*machine, peers)
|
profiles := h.db.getMapResponseUserProfiles(*machine, peers)
|
||||||
|
|
||||||
nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -107,7 +108,7 @@ func (h *Headscale) generateMapResponse(
|
||||||
Str("func", "generateMapResponse").
|
Str("func", "generateMapResponse").
|
||||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||||
// Interface("payload", resp).
|
// Interface("payload", resp).
|
||||||
Msgf("Generated map response: %s", tailMapResponseToString(resp))
|
Msgf("Generated map response: %s", util.TailMapResponseToString(resp))
|
||||||
|
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package hscontrol
|
package hscontrol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
@ -13,10 +15,10 @@ import (
|
||||||
const (
|
const (
|
||||||
apiPrefixLength = 7
|
apiPrefixLength = 7
|
||||||
apiKeyLength = 32
|
apiKeyLength = 32
|
||||||
|
|
||||||
ErrAPIKeyFailedToParse = Error("Failed to parse ApiKey")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
||||||
|
|
||||||
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
||||||
// headscale.
|
// headscale.
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
|
@ -30,15 +32,15 @@ type APIKey struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
||||||
func (h *Headscale) CreateAPIKey(
|
func (hsdb *HSDatabase) CreateAPIKey(
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
) (string, *APIKey, error) {
|
) (string, *APIKey, error) {
|
||||||
prefix, err := GenerateRandomStringURLSafe(apiPrefixLength)
|
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
toBeHashed, err := GenerateRandomStringURLSafe(apiKeyLength)
|
toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -57,7 +59,7 @@ func (h *Headscale) CreateAPIKey(
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Save(&key).Error; err != nil {
|
if err := hsdb.db.Save(&key).Error; err != nil {
|
||||||
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
|
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,9 +67,9 @@ func (h *Headscale) CreateAPIKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||||
func (h *Headscale) ListAPIKeys() ([]APIKey, error) {
|
func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
|
||||||
keys := []APIKey{}
|
keys := []APIKey{}
|
||||||
if err := h.db.Find(&keys).Error; err != nil {
|
if err := hsdb.db.Find(&keys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,9 +77,9 @@ func (h *Headscale) ListAPIKeys() ([]APIKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAPIKey returns a ApiKey for a given key.
|
// GetAPIKey returns a ApiKey for a given key.
|
||||||
func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) {
|
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
|
||||||
key := APIKey{}
|
key := APIKey{}
|
||||||
if result := h.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,9 +87,9 @@ func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||||
func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
||||||
key := APIKey{}
|
key := APIKey{}
|
||||||
if result := h.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil {
|
if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,8 +98,8 @@ func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
||||||
|
|
||||||
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
||||||
// does not exist.
|
// does not exist.
|
||||||
func (h *Headscale) DestroyAPIKey(key APIKey) error {
|
func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
|
||||||
if result := h.db.Unscoped().Delete(key); result.Error != nil {
|
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,21 +107,21 @@ func (h *Headscale) DestroyAPIKey(key APIKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpireAPIKey marks a ApiKey as expired.
|
// ExpireAPIKey marks a ApiKey as expired.
|
||||||
func (h *Headscale) ExpireAPIKey(key *APIKey) error {
|
func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error {
|
||||||
if err := h.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) {
|
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
|
||||||
prefix, hash, found := strings.Cut(keyStr, ".")
|
prefix, hash, found := strings.Cut(keyStr, ".")
|
||||||
if !found {
|
if !found {
|
||||||
return false, ErrAPIKeyFailedToParse
|
return false, ErrAPIKeyFailedToParse
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := h.GetAPIKey(prefix)
|
key, err := hsdb.GetAPIKey(prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to validate api key: %w", err)
|
return false, fmt.Errorf("failed to validate api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreateAPIKey(c *check.C) {
|
func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||||
apiKeyStr, apiKey, err := app.CreateAPIKey(nil)
|
apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
|
@ -16,74 +16,74 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||||
c.Assert(apiKey.Hash, check.NotNil)
|
c.Assert(apiKey.Hash, check.NotNil)
|
||||||
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
||||||
|
|
||||||
_, err = app.ListAPIKeys()
|
_, err = app.db.ListAPIKeys()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
keys, err := app.ListAPIKeys()
|
keys, err := app.db.ListAPIKeys()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||||
key, err := app.GetAPIKey("does-not-exist")
|
key, err := app.db.GetAPIKey("does-not-exist")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
||||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2)
|
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.ValidateAPIKey(apiKeyStr)
|
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, true)
|
c.Assert(valid, check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||||
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowMinus2)
|
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.ValidateAPIKey(apiKeyStr)
|
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, false)
|
c.Assert(valid, check.Equals, false)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
apiKeyStrNow, apiKey, err := app.CreateAPIKey(&now)
|
apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
validNow, err := app.ValidateAPIKey(apiKeyStrNow)
|
validNow, err := app.db.ValidateAPIKey(apiKeyStrNow)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(validNow, check.Equals, false)
|
c.Assert(validNow, check.Equals, false)
|
||||||
|
|
||||||
validSilly, err := app.ValidateAPIKey("nota.validkey")
|
validSilly, err := app.db.ValidateAPIKey("nota.validkey")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(validSilly, check.Equals, false)
|
c.Assert(validSilly, check.Equals, false)
|
||||||
|
|
||||||
validWithErr, err := app.ValidateAPIKey("produceerrorkey")
|
validWithErr, err := app.db.ValidateAPIKey("produceerrorkey")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(validWithErr, check.Equals, false)
|
c.Assert(validWithErr, check.Equals, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||||
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2)
|
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey, check.NotNil)
|
c.Assert(apiKey, check.NotNil)
|
||||||
|
|
||||||
valid, err := app.ValidateAPIKey(apiKeyStr)
|
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, true)
|
c.Assert(valid, check.Equals, true)
|
||||||
|
|
||||||
err = app.ExpireAPIKey(apiKey)
|
err = app.db.ExpireAPIKey(apiKey)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(apiKey.Expiration, check.NotNil)
|
c.Assert(apiKey.Expiration, check.NotNil)
|
||||||
|
|
||||||
notValid, err := app.ValidateAPIKey(apiKeyStr)
|
notValid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(notValid, check.Equals, false)
|
c.Assert(notValid, check.Equals, false)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||||
"github.com/juanfont/headscale"
|
"github.com/juanfont/headscale"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
@ -41,24 +42,21 @@ import (
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/reflection"
|
"google.golang.org/grpc/reflection"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"gorm.io/gorm"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
errSTUNAddressNotSet = Error("STUN address not set")
|
errSTUNAddressNotSet = errors.New("STUN address not set")
|
||||||
errUnsupportedDatabase = Error("unsupported DB")
|
errUnsupportedDatabase = errors.New("unsupported DB")
|
||||||
errUnsupportedLetsEncryptChallengeType = Error(
|
errUnsupportedLetsEncryptChallengeType = errors.New(
|
||||||
"unknown value for Lets Encrypt challenge type",
|
"unknown value for Lets Encrypt challenge type",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AuthPrefix = "Bearer "
|
AuthPrefix = "Bearer "
|
||||||
Postgres = "postgres"
|
|
||||||
Sqlite = "sqlite3"
|
|
||||||
updateInterval = 5000
|
updateInterval = 5000
|
||||||
HTTPReadTimeout = 30 * time.Second
|
HTTPReadTimeout = 30 * time.Second
|
||||||
HTTPShutdownTimeout = 3 * time.Second
|
HTTPShutdownTimeout = 3 * time.Second
|
||||||
|
@ -75,7 +73,7 @@ const (
|
||||||
// Headscale represents the base app of the service.
|
// Headscale represents the base app of the service.
|
||||||
type Headscale struct {
|
type Headscale struct {
|
||||||
cfg *Config
|
cfg *Config
|
||||||
db *gorm.DB
|
db *HSDatabase
|
||||||
dbString string
|
dbString string
|
||||||
dbType string
|
dbType string
|
||||||
dbDebug bool
|
dbDebug bool
|
||||||
|
@ -96,10 +94,11 @@ type Headscale struct {
|
||||||
|
|
||||||
registrationCache *cache.Cache
|
registrationCache *cache.Cache
|
||||||
|
|
||||||
ipAllocationMutex sync.Mutex
|
|
||||||
|
|
||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
pollNetMapStreamWG sync.WaitGroup
|
pollNetMapStreamWG sync.WaitGroup
|
||||||
|
|
||||||
|
stateUpdateChan chan struct{}
|
||||||
|
cancelStateUpdateChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
|
@ -164,13 +163,27 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
pollNetMapStreamWG: sync.WaitGroup{},
|
pollNetMapStreamWG: sync.WaitGroup{},
|
||||||
lastStateChange: xsync.NewMapOf[time.Time](),
|
lastStateChange: xsync.NewMapOf[time.Time](),
|
||||||
|
|
||||||
|
stateUpdateChan: make(chan struct{}),
|
||||||
|
cancelStateUpdateChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.initDB()
|
go app.watchStateChannel()
|
||||||
|
|
||||||
|
db, err := NewHeadscaleDatabase(
|
||||||
|
cfg.DBtype,
|
||||||
|
dbString,
|
||||||
|
cfg.OIDC.StripEmaildomain,
|
||||||
|
app.dbDebug,
|
||||||
|
app.stateUpdateChan,
|
||||||
|
cfg.IPPrefixes,
|
||||||
|
cfg.BaseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app.db = db
|
||||||
|
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = app.initOIDC()
|
err = app.initOIDC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -231,7 +244,7 @@ func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
||||||
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
err := h.handlePrimarySubnetFailover()
|
err := h.db.handlePrimarySubnetFailover()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
||||||
}
|
}
|
||||||
|
@ -239,7 +252,7 @@ func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expireEphemeralNodesWorker() {
|
func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
users, err := h.ListUsers()
|
users, err := h.db.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
log.Error().Err(err).Msg("Error listing users")
|
||||||
|
|
||||||
|
@ -247,7 +260,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
machines, err := h.ListMachinesByUser(user.Name)
|
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -267,7 +280,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("Ephemeral client removed from database")
|
Msg("Ephemeral client removed from database")
|
||||||
|
|
||||||
err = h.db.Unscoped().Delete(machine).Error
|
err = h.db.db.Unscoped().Delete(machine).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -284,7 +297,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expireExpiredMachinesWorker() {
|
func (h *Headscale) expireExpiredMachinesWorker() {
|
||||||
users, err := h.ListUsers()
|
users, err := h.db.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
log.Error().Err(err).Msg("Error listing users")
|
||||||
|
|
||||||
|
@ -292,7 +305,7 @@ func (h *Headscale) expireExpiredMachinesWorker() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
machines, err := h.ListMachinesByUser(user.Name)
|
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -308,7 +321,7 @@ func (h *Headscale) expireExpiredMachinesWorker() {
|
||||||
machine.Expiry.After(h.getLastStateChange(user)) {
|
machine.Expiry.After(h.getLastStateChange(user)) {
|
||||||
expiredFound = true
|
expiredFound = true
|
||||||
|
|
||||||
err := h.ExpireMachine(&machines[index])
|
err := h.db.ExpireMachine(&machines[index])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -387,7 +400,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := h.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -438,7 +451,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -597,7 +610,7 @@ func (h *Headscale) Serve() error {
|
||||||
h.cfg.UnixSocket,
|
h.cfg.UnixSocket,
|
||||||
[]grpc.DialOption{
|
[]grpc.DialOption{
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithContextDialer(GrpcSocketDialer),
|
grpc.WithContextDialer(util.GrpcSocketDialer),
|
||||||
}...,
|
}...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -760,7 +773,7 @@ func (h *Headscale) Serve() error {
|
||||||
// TODO(kradalby): Reload config on SIGHUP
|
// TODO(kradalby): Reload config on SIGHUP
|
||||||
|
|
||||||
if h.cfg.ACL.PolicyPath != "" {
|
if h.cfg.ACL.PolicyPath != "" {
|
||||||
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
||||||
err := h.LoadACLPolicyFromPath(aclPath)
|
err := h.LoadACLPolicyFromPath(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
||||||
|
@ -778,6 +791,7 @@ func (h *Headscale) Serve() error {
|
||||||
Msg("Received signal to stop, shutting down gracefully")
|
Msg("Received signal to stop, shutting down gracefully")
|
||||||
|
|
||||||
close(h.shutdownChan)
|
close(h.shutdownChan)
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Wait()
|
h.pollNetMapStreamWG.Wait()
|
||||||
|
|
||||||
// Gracefully shut down servers
|
// Gracefully shut down servers
|
||||||
|
@ -806,8 +820,12 @@ func (h *Headscale) Serve() error {
|
||||||
// Stop listening (and unlink the socket if unix type):
|
// Stop listening (and unlink the socket if unix type):
|
||||||
socketListener.Close()
|
socketListener.Close()
|
||||||
|
|
||||||
|
<-h.cancelStateUpdateChan
|
||||||
|
close(h.stateUpdateChan)
|
||||||
|
close(h.cancelStateUpdateChan)
|
||||||
|
|
||||||
// Close db connections
|
// Close db connections
|
||||||
db, err := h.db.DB()
|
db, err := h.db.db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to get db handle")
|
log.Error().Err(err).Msg("Failed to get db handle")
|
||||||
}
|
}
|
||||||
|
@ -905,12 +923,25 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): baby steps, make this more robust.
|
||||||
|
func (h *Headscale) watchStateChannel() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.stateUpdateChan:
|
||||||
|
h.setLastStateChangeToNow()
|
||||||
|
|
||||||
|
case <-h.cancelStateUpdateChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Headscale) setLastStateChangeToNow() {
|
func (h *Headscale) setLastStateChangeToNow() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
users, err := h.ListUsers()
|
users, err := h.db.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -1002,7 +1033,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||||
privateKeyEnsurePrefix := PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
||||||
|
|
||||||
var machineKey key.MachinePrivate
|
var machineKey key.MachinePrivate
|
||||||
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
|
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
|
||||||
|
|
|
@ -42,18 +42,32 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
IPPrefixes: []netip.Prefix{
|
IPPrefixes: []netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
},
|
},
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
StripEmaildomain: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift
|
||||||
app = Headscale{
|
app = Headscale{
|
||||||
cfg: &cfg,
|
cfg: &cfg,
|
||||||
dbType: "sqlite3",
|
dbType: "sqlite3",
|
||||||
dbString: tmpDir + "/headscale_test.db",
|
dbString: tmpDir + "/headscale_test.db",
|
||||||
|
|
||||||
|
stateUpdateChan: make(chan struct{}),
|
||||||
|
cancelStateUpdateChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
err = app.initDB()
|
|
||||||
if err != nil {
|
go app.watchStateChannel()
|
||||||
c.Fatal(err)
|
|
||||||
}
|
db, err := NewHeadscaleDatabase(
|
||||||
db, err := app.openDB()
|
app.dbType,
|
||||||
|
app.dbString,
|
||||||
|
cfg.OIDC.StripEmaildomain,
|
||||||
|
false,
|
||||||
|
app.stateUpdateChan,
|
||||||
|
cfg.IPPrefixes,
|
||||||
|
"",
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/prometheus/common/model"
|
"github.com/prometheus/common/model"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -271,15 +272,15 @@ func GetTLSConfig() TLSConfig {
|
||||||
LetsEncrypt: LetsEncryptConfig{
|
LetsEncrypt: LetsEncryptConfig{
|
||||||
Hostname: viper.GetString("tls_letsencrypt_hostname"),
|
Hostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||||
Listen: viper.GetString("tls_letsencrypt_listen"),
|
Listen: viper.GetString("tls_letsencrypt_listen"),
|
||||||
CacheDir: AbsolutePathFromConfigPath(
|
CacheDir: util.AbsolutePathFromConfigPath(
|
||||||
viper.GetString("tls_letsencrypt_cache_dir"),
|
viper.GetString("tls_letsencrypt_cache_dir"),
|
||||||
),
|
),
|
||||||
ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
||||||
},
|
},
|
||||||
CertPath: AbsolutePathFromConfigPath(
|
CertPath: util.AbsolutePathFromConfigPath(
|
||||||
viper.GetString("tls_cert_path"),
|
viper.GetString("tls_cert_path"),
|
||||||
),
|
),
|
||||||
KeyPath: AbsolutePathFromConfigPath(
|
KeyPath: util.AbsolutePathFromConfigPath(
|
||||||
viper.GetString("tls_key_path"),
|
viper.GetString("tls_key_path"),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -585,10 +586,10 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||||
|
|
||||||
IPPrefixes: prefixes,
|
IPPrefixes: prefixes,
|
||||||
PrivateKeyPath: AbsolutePathFromConfigPath(
|
PrivateKeyPath: util.AbsolutePathFromConfigPath(
|
||||||
viper.GetString("private_key_path"),
|
viper.GetString("private_key_path"),
|
||||||
),
|
),
|
||||||
NoisePrivateKeyPath: AbsolutePathFromConfigPath(
|
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
|
||||||
viper.GetString("noise.private_key_path"),
|
viper.GetString("noise.private_key_path"),
|
||||||
),
|
),
|
||||||
BaseDomain: baseDomain,
|
BaseDomain: baseDomain,
|
||||||
|
@ -604,7 +605,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
),
|
),
|
||||||
|
|
||||||
DBtype: viper.GetString("db_type"),
|
DBtype: viper.GetString("db_type"),
|
||||||
DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")),
|
DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")),
|
||||||
DBhost: viper.GetString("db_host"),
|
DBhost: viper.GetString("db_host"),
|
||||||
DBport: viper.GetInt("db_port"),
|
DBport: viper.GetInt("db_port"),
|
||||||
DBname: viper.GetString("db_name"),
|
DBname: viper.GetString("db_name"),
|
||||||
|
@ -620,7 +621,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
||||||
ACMEURL: viper.GetString("acme_url"),
|
ACMEURL: viper.GetString("acme_url"),
|
||||||
|
|
||||||
UnixSocket: viper.GetString("unix_socket"),
|
UnixSocket: viper.GetString("unix_socket"),
|
||||||
UnixSocketPermission: GetFileMode("unix_socket_permission"),
|
UnixSocketPermission: util.GetFileMode("unix_socket_permission"),
|
||||||
|
|
||||||
OIDC: OIDCConfig{
|
OIDC: OIDCConfig{
|
||||||
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
||||||
|
|
184
hscontrol/db.go
184
hscontrol/db.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
|
@ -19,55 +20,90 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dbVersion = "1"
|
dbVersion = "1"
|
||||||
|
Postgres = "postgres"
|
||||||
|
Sqlite = "sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
errValueNotFound = Error("not found")
|
var (
|
||||||
ErrCannotParsePrefix = Error("cannot parse prefix")
|
errValueNotFound = errors.New("not found")
|
||||||
|
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||||
|
errDatabaseNotSupported = errors.New("database type not supported")
|
||||||
)
|
)
|
||||||
|
|
||||||
// KV is a key-value store in a psql table. For future use...
|
// KV is a key-value store in a psql table. For future use...
|
||||||
|
// TODO(kradalby): Is this used for anything?
|
||||||
type KV struct {
|
type KV struct {
|
||||||
Key string
|
Key string
|
||||||
Value string
|
Value string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) initDB() error {
|
type HSDatabase struct {
|
||||||
db, err := h.openDB()
|
db *gorm.DB
|
||||||
|
notifyStateChan chan<- struct{}
|
||||||
|
|
||||||
|
ipAllocationMutex sync.Mutex
|
||||||
|
|
||||||
|
ipPrefixes []netip.Prefix
|
||||||
|
baseDomain string
|
||||||
|
stripEmailDomain bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): assemble this struct from toptions or something typed
|
||||||
|
// rather than arguments.
|
||||||
|
func NewHeadscaleDatabase(
|
||||||
|
dbType, connectionAddr string,
|
||||||
|
stripEmailDomain, debug bool,
|
||||||
|
notifyStateChan chan<- struct{},
|
||||||
|
ipPrefixes []netip.Prefix,
|
||||||
|
baseDomain string,
|
||||||
|
) (*HSDatabase, error) {
|
||||||
|
dbConn, err := openDB(dbType, connectionAddr, debug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
|
||||||
h.db = db
|
|
||||||
|
|
||||||
if h.dbType == Postgres {
|
|
||||||
db.Exec(`create extension if not exists "uuid-ossp";`)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = db.Migrator().RenameTable("namespaces", "users")
|
db := HSDatabase{
|
||||||
|
db: dbConn,
|
||||||
|
notifyStateChan: notifyStateChan,
|
||||||
|
|
||||||
err = db.AutoMigrate(&User{})
|
ipPrefixes: ipPrefixes,
|
||||||
|
baseDomain: baseDomain,
|
||||||
|
stripEmailDomain: stripEmailDomain,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("database %#v", dbConn)
|
||||||
|
|
||||||
|
if dbType == Postgres {
|
||||||
|
dbConn.Exec(`create extension if not exists "uuid-ossp";`)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
||||||
|
|
||||||
|
err = dbConn.AutoMigrate(User{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
||||||
_ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
||||||
|
|
||||||
_ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
||||||
_ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
||||||
|
|
||||||
// GivenName is used as the primary source of DNS names, make sure
|
// GivenName is used as the primary source of DNS names, make sure
|
||||||
// the field is populated and normalized if it was not when the
|
// the field is populated and normalized if it was not when the
|
||||||
// machine was registered.
|
// machine was registered.
|
||||||
_ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
||||||
|
|
||||||
// If the Machine table has a column for registered,
|
// If the Machine table has a column for registered,
|
||||||
// find all occourences of "false" and drop them. Then
|
// find all occourences of "false" and drop them. Then
|
||||||
// remove the column.
|
// remove the column.
|
||||||
if db.Migrator().HasColumn(&Machine{}, "registered") {
|
if dbConn.Migrator().HasColumn(&Machine{}, "registered") {
|
||||||
log.Info().
|
log.Info().
|
||||||
Msg(`Database has legacy "registered" column in machine, removing...`)
|
Msg(`Database has legacy "registered" column in machine, removing...`)
|
||||||
|
|
||||||
machines := Machines{}
|
machines := Machines{}
|
||||||
if err := h.db.Not("registered").Find(&machines).Error; err != nil {
|
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +112,7 @@ func (h *Headscale) initDB() error {
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Str("machine_key", machine.MachineKey).
|
Str("machine_key", machine.MachineKey).
|
||||||
Msg("Deleting unregistered machine")
|
Msg("Deleting unregistered machine")
|
||||||
if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil {
|
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
|
@ -85,18 +121,18 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := db.Migrator().DropColumn(&Machine{}, "registered")
|
err := dbConn.Migrator().DropColumn(&Machine{}, "registered")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error dropping registered column")
|
log.Error().Err(err).Msg("Error dropping registered column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&Route{})
|
err = dbConn.AutoMigrate(&Route{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
||||||
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
||||||
|
|
||||||
type MachineAux struct {
|
type MachineAux struct {
|
||||||
|
@ -105,7 +141,7 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
machinesAux := []MachineAux{}
|
machinesAux := []MachineAux{}
|
||||||
err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
|
err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("Error accessing db")
|
log.Fatal().Err(err).Msg("Error accessing db")
|
||||||
}
|
}
|
||||||
|
@ -120,7 +156,7 @@ func (h *Headscale) initDB() error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Preload("Machine").
|
err = dbConn.Preload("Machine").
|
||||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
||||||
First(&Route{}).
|
First(&Route{}).
|
||||||
Error
|
Error
|
||||||
|
@ -138,7 +174,7 @@ func (h *Headscale) initDB() error {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Prefix: IPPrefix(prefix),
|
Prefix: IPPrefix(prefix),
|
||||||
}
|
}
|
||||||
if err := h.db.Create(&route).Error; err != nil {
|
if err := dbConn.Create(&route).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error creating route")
|
log.Error().Err(err).Msg("Error creating route")
|
||||||
} else {
|
} else {
|
||||||
log.Info().
|
log.Info().
|
||||||
|
@ -149,20 +185,20 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&Machine{})
|
err = dbConn.AutoMigrate(&Machine{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Migrator().HasColumn(&Machine{}, "given_name") {
|
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") {
|
||||||
machines := Machines{}
|
machines := Machines{}
|
||||||
if err := h.db.Find(&machines).Error; err != nil {
|
if err := dbConn.Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,7 +206,7 @@ func (h *Headscale) initDB() error {
|
||||||
if machine.GivenName == "" {
|
if machine.GivenName == "" {
|
||||||
normalizedHostname, err := NormalizeToFQDNRules(
|
normalizedHostname, err := NormalizeToFQDNRules(
|
||||||
machine.Hostname,
|
machine.Hostname,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
stripEmailDomain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -180,7 +216,7 @@ func (h *Headscale) initDB() error {
|
||||||
Msg("Failed to normalize machine hostname in DB migration")
|
Msg("Failed to normalize machine hostname in DB migration")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.RenameMachine(&machines[item], normalizedHostname)
|
err = db.RenameMachine(&machines[item], normalizedHostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -192,51 +228,51 @@ func (h *Headscale) initDB() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&KV{})
|
err = dbConn.AutoMigrate(&KV{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&PreAuthKey{})
|
err = dbConn.AutoMigrate(&PreAuthKey{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.AutoMigrate(&PreAuthKeyACLTag{})
|
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = db.Migrator().DropTable("shared_machines")
|
_ = dbConn.Migrator().DropTable("shared_machines")
|
||||||
|
|
||||||
err = db.AutoMigrate(&APIKey{})
|
err = dbConn.AutoMigrate(&APIKey{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.setValue("db_version", dbVersion)
|
// TODO(kradalby): is this needed?
|
||||||
|
err = db.setValue("db_version", dbVersion)
|
||||||
|
|
||||||
return err
|
return &db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) openDB() (*gorm.DB, error) {
|
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
||||||
var db *gorm.DB
|
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
|
||||||
var err error
|
|
||||||
|
|
||||||
var log logger.Interface
|
var dbLogger logger.Interface
|
||||||
if h.dbDebug {
|
if debug {
|
||||||
log = logger.Default
|
dbLogger = logger.Default
|
||||||
} else {
|
} else {
|
||||||
log = logger.Default.LogMode(logger.Silent)
|
dbLogger = logger.Default.LogMode(logger.Silent)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch h.dbType {
|
switch dbType {
|
||||||
case Sqlite:
|
case Sqlite:
|
||||||
db, err = gorm.Open(
|
db, err := gorm.Open(
|
||||||
sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"),
|
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
|
||||||
&gorm.Config{
|
&gorm.Config{
|
||||||
DisableForeignKeyConstraintWhenMigrating: true,
|
DisableForeignKeyConstraintWhenMigrating: true,
|
||||||
Logger: log,
|
Logger: dbLogger,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -250,24 +286,30 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
||||||
sqlDB.SetMaxOpenConns(1)
|
sqlDB.SetMaxOpenConns(1)
|
||||||
sqlDB.SetConnMaxIdleTime(time.Hour)
|
sqlDB.SetConnMaxIdleTime(time.Hour)
|
||||||
|
|
||||||
|
return db, err
|
||||||
|
|
||||||
case Postgres:
|
case Postgres:
|
||||||
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
|
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
|
||||||
DisableForeignKeyConstraintWhenMigrating: true,
|
DisableForeignKeyConstraintWhenMigrating: true,
|
||||||
Logger: log,
|
Logger: dbLogger,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
return nil, fmt.Errorf(
|
||||||
return nil, err
|
"database of type %s is not supported: %w",
|
||||||
|
dbType,
|
||||||
|
errDatabaseNotSupported,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db, nil
|
func (hsdb *HSDatabase) notifyStateChange() {
|
||||||
|
hsdb.notifyStateChan <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValue returns the value for the given key in KV.
|
// getValue returns the value for the given key in KV.
|
||||||
func (h *Headscale) getValue(key string) (string, error) {
|
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||||
var row KV
|
var row KV
|
||||||
if result := h.db.First(&row, "key = ?", key); errors.Is(
|
if result := hsdb.db.First(&row, "key = ?", key); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -278,34 +320,34 @@ func (h *Headscale) getValue(key string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// setValue sets value for the given key in KV.
|
// setValue sets value for the given key in KV.
|
||||||
func (h *Headscale) setValue(key string, value string) error {
|
func (hsdb *HSDatabase) setValue(key string, value string) error {
|
||||||
keyValue := KV{
|
keyValue := KV{
|
||||||
Key: key,
|
Key: key,
|
||||||
Value: value,
|
Value: value,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := h.getValue(key); err == nil {
|
if _, err := hsdb.getValue(key); err == nil {
|
||||||
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Create(keyValue).Error; err != nil {
|
if err := hsdb.db.Create(keyValue).Error; err != nil {
|
||||||
return fmt.Errorf("failed to create key value pair in the database: %w", err)
|
return fmt.Errorf("failed to create key value pair in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) pingDB(ctx context.Context) error {
|
func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
db, err := h.db.DB()
|
sqlDB, err := hsdb.db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.PingContext(ctx)
|
return sqlDB.PingContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a "wrapper" type around tailscales
|
// This is a "wrapper" type around tailscales
|
||||||
|
|
|
@ -112,16 +112,16 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
userShared1, err := app.CreateUser("shared1")
|
userShared1, err := app.db.CreateUser("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared2, err := app.CreateUser("shared2")
|
userShared2, err := app.db.CreateUser("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared3, err := app.CreateUser("shared3")
|
userShared3, err := app.db.CreateUser("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -130,7 +130,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
||||||
userShared2.Name,
|
userShared2.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -139,7 +139,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
||||||
userShared3.Name,
|
userShared3.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -148,7 +148,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
PreAuthKey2InShared1, err := app.CreatePreAuthKey(
|
PreAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -157,7 +157,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
|
@ -172,9 +172,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared1)
|
app.db.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
|
@ -189,9 +189,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared2)
|
app.db.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
|
@ -206,9 +206,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared3)
|
app.db.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
|
@ -223,7 +223,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine2InShared1)
|
app.db.db.Save(machine2InShared1)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
dnsConfigOrig := tailcfg.DNSConfig{
|
||||||
|
@ -232,7 +232,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
Proxied: true,
|
Proxied: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
peersOfMachineInShared1, err := app.getPeers(machineInShared1)
|
peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
|
@ -259,16 +259,16 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
userShared1, err := app.CreateUser("shared1")
|
userShared1, err := app.db.CreateUser("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared2, err := app.CreateUser("shared2")
|
userShared2, err := app.db.CreateUser("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared3, err := app.CreateUser("shared3")
|
userShared3, err := app.db.CreateUser("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
preAuthKeyInShared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -277,7 +277,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
preAuthKeyInShared2, err := app.db.CreatePreAuthKey(
|
||||||
userShared2.Name,
|
userShared2.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -286,7 +286,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
preAuthKeyInShared3, err := app.db.CreatePreAuthKey(
|
||||||
userShared3.Name,
|
userShared3.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -295,7 +295,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKey2InShared1, err := app.CreatePreAuthKey(
|
preAuthKey2InShared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -304,7 +304,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
|
@ -319,9 +319,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared1)
|
app.db.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
|
@ -336,9 +336,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared2)
|
app.db.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
|
@ -353,9 +353,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared3)
|
app.db.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
|
@ -370,7 +370,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine2InShared1)
|
app.db.db.Save(machine2InShared1)
|
||||||
|
|
||||||
baseDomain := "foobar.headscale.net"
|
baseDomain := "foobar.headscale.net"
|
||||||
dnsConfigOrig := tailcfg.DNSConfig{
|
dnsConfigOrig := tailcfg.DNSConfig{
|
||||||
|
@ -379,7 +379,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||||
Proxied: false,
|
Proxied: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
|
peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -30,7 +31,7 @@ func (api headscaleV1APIServer) GetUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetUserRequest,
|
request *v1.GetUserRequest,
|
||||||
) (*v1.GetUserResponse, error) {
|
) (*v1.GetUserResponse, error) {
|
||||||
user, err := api.h.GetUser(request.GetName())
|
user, err := api.h.db.GetUser(request.GetName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -42,7 +43,7 @@ func (api headscaleV1APIServer) CreateUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.CreateUserRequest,
|
request *v1.CreateUserRequest,
|
||||||
) (*v1.CreateUserResponse, error) {
|
) (*v1.CreateUserResponse, error) {
|
||||||
user, err := api.h.CreateUser(request.GetName())
|
user, err := api.h.db.CreateUser(request.GetName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -54,12 +55,12 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameUserRequest,
|
request *v1.RenameUserRequest,
|
||||||
) (*v1.RenameUserResponse, error) {
|
) (*v1.RenameUserResponse, error) {
|
||||||
err := api.h.RenameUser(request.GetOldName(), request.GetNewName())
|
err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.GetUser(request.GetNewName())
|
user, err := api.h.db.GetUser(request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -71,7 +72,7 @@ func (api headscaleV1APIServer) DeleteUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteUserRequest,
|
request *v1.DeleteUserRequest,
|
||||||
) (*v1.DeleteUserResponse, error) {
|
) (*v1.DeleteUserResponse, error) {
|
||||||
err := api.h.DestroyUser(request.GetName())
|
err := api.h.db.DestroyUser(request.GetName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -83,7 +84,7 @@ func (api headscaleV1APIServer) ListUsers(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListUsersRequest,
|
request *v1.ListUsersRequest,
|
||||||
) (*v1.ListUsersResponse, error) {
|
) (*v1.ListUsersResponse, error) {
|
||||||
users, err := api.h.ListUsers()
|
users, err := api.h.db.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -116,7 +117,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preAuthKey, err := api.h.CreatePreAuthKey(
|
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
||||||
request.GetUser(),
|
request.GetUser(),
|
||||||
request.GetReusable(),
|
request.GetReusable(),
|
||||||
request.GetEphemeral(),
|
request.GetEphemeral(),
|
||||||
|
@ -134,12 +135,12 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ExpirePreAuthKeyRequest,
|
request *v1.ExpirePreAuthKeyRequest,
|
||||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||||
preAuthKey, err := api.h.GetPreAuthKey(request.GetUser(), request.Key)
|
preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.ExpirePreAuthKey(preAuthKey)
|
err = api.h.db.ExpirePreAuthKey(preAuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -151,7 +152,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListPreAuthKeysRequest,
|
request *v1.ListPreAuthKeysRequest,
|
||||||
) (*v1.ListPreAuthKeysResponse, error) {
|
) (*v1.ListPreAuthKeysResponse, error) {
|
||||||
preAuthKeys, err := api.h.ListPreAuthKeys(request.GetUser())
|
preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -173,7 +174,8 @@ func (api headscaleV1APIServer) RegisterMachine(
|
||||||
Str("node_key", request.GetKey()).
|
Str("node_key", request.GetKey()).
|
||||||
Msg("Registering machine")
|
Msg("Registering machine")
|
||||||
|
|
||||||
machine, err := api.h.RegisterMachineFromAuthCallback(
|
machine, err := api.h.db.RegisterMachineFromAuthCallback(
|
||||||
|
api.h.registrationCache,
|
||||||
request.GetKey(),
|
request.GetKey(),
|
||||||
request.GetUser(),
|
request.GetUser(),
|
||||||
nil,
|
nil,
|
||||||
|
@ -190,7 +192,7 @@ func (api headscaleV1APIServer) GetMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetMachineRequest,
|
request *v1.GetMachineRequest,
|
||||||
) (*v1.GetMachineResponse, error) {
|
) (*v1.GetMachineResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -202,7 +204,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.SetTagsRequest,
|
request *v1.SetTagsRequest,
|
||||||
) (*v1.SetTagsResponse, error) {
|
) (*v1.SetTagsResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -216,7 +218,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.SetTags(machine, request.GetTags())
|
err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &v1.SetTagsResponse{
|
return &v1.SetTagsResponse{
|
||||||
Machine: nil,
|
Machine: nil,
|
||||||
|
@ -248,12 +250,12 @@ func (api headscaleV1APIServer) DeleteMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteMachineRequest,
|
request *v1.DeleteMachineRequest,
|
||||||
) (*v1.DeleteMachineResponse, error) {
|
) (*v1.DeleteMachineResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.DeleteMachine(
|
err = api.h.db.DeleteMachine(
|
||||||
machine,
|
machine,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -267,12 +269,12 @@ func (api headscaleV1APIServer) ExpireMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ExpireMachineRequest,
|
request *v1.ExpireMachineRequest,
|
||||||
) (*v1.ExpireMachineResponse, error) {
|
) (*v1.ExpireMachineResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
api.h.ExpireMachine(
|
api.h.db.ExpireMachine(
|
||||||
machine,
|
machine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -288,12 +290,12 @@ func (api headscaleV1APIServer) RenameMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameMachineRequest,
|
request *v1.RenameMachineRequest,
|
||||||
) (*v1.RenameMachineResponse, error) {
|
) (*v1.RenameMachineResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.RenameMachine(
|
err = api.h.db.RenameMachine(
|
||||||
machine,
|
machine,
|
||||||
request.GetNewName(),
|
request.GetNewName(),
|
||||||
)
|
)
|
||||||
|
@ -314,7 +316,7 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
request *v1.ListMachinesRequest,
|
request *v1.ListMachinesRequest,
|
||||||
) (*v1.ListMachinesResponse, error) {
|
) (*v1.ListMachinesResponse, error) {
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
machines, err := api.h.ListMachinesByUser(request.GetUser())
|
machines, err := api.h.db.ListMachinesByUser(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -327,7 +329,7 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
return &v1.ListMachinesResponse{Machines: response}, nil
|
return &v1.ListMachinesResponse{Machines: response}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
machines, err := api.h.ListMachines()
|
machines, err := api.h.db.ListMachines()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -352,12 +354,12 @@ func (api headscaleV1APIServer) MoveMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveMachineRequest,
|
request *v1.MoveMachineRequest,
|
||||||
) (*v1.MoveMachineResponse, error) {
|
) (*v1.MoveMachineResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.SetMachineUser(machine, request.GetUser())
|
err = api.h.db.SetMachineUser(machine, request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -369,7 +371,7 @@ func (api headscaleV1APIServer) GetRoutes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetRoutesRequest,
|
request *v1.GetRoutesRequest,
|
||||||
) (*v1.GetRoutesResponse, error) {
|
) (*v1.GetRoutesResponse, error) {
|
||||||
routes, err := api.h.GetRoutes()
|
routes, err := api.h.db.GetRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -383,7 +385,7 @@ func (api headscaleV1APIServer) EnableRoute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.EnableRouteRequest,
|
request *v1.EnableRouteRequest,
|
||||||
) (*v1.EnableRouteResponse, error) {
|
) (*v1.EnableRouteResponse, error) {
|
||||||
err := api.h.EnableRoute(request.GetRouteId())
|
err := api.h.db.EnableRoute(request.GetRouteId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -395,7 +397,7 @@ func (api headscaleV1APIServer) DisableRoute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DisableRouteRequest,
|
request *v1.DisableRouteRequest,
|
||||||
) (*v1.DisableRouteResponse, error) {
|
) (*v1.DisableRouteResponse, error) {
|
||||||
err := api.h.DisableRoute(request.GetRouteId())
|
err := api.h.db.DisableRoute(request.GetRouteId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -407,12 +409,12 @@ func (api headscaleV1APIServer) GetMachineRoutes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetMachineRoutesRequest,
|
request *v1.GetMachineRoutesRequest,
|
||||||
) (*v1.GetMachineRoutesResponse, error) {
|
) (*v1.GetMachineRoutesResponse, error) {
|
||||||
machine, err := api.h.GetMachineByID(request.GetMachineId())
|
machine, err := api.h.db.GetMachineByID(request.GetMachineId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := api.h.GetMachineRoutes(machine)
|
routes, err := api.h.db.GetMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -426,7 +428,7 @@ func (api headscaleV1APIServer) DeleteRoute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteRouteRequest,
|
request *v1.DeleteRouteRequest,
|
||||||
) (*v1.DeleteRouteResponse, error) {
|
) (*v1.DeleteRouteResponse, error) {
|
||||||
err := api.h.DeleteRoute(request.GetRouteId())
|
err := api.h.db.DeleteRoute(request.GetRouteId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -443,7 +445,7 @@ func (api headscaleV1APIServer) CreateApiKey(
|
||||||
expiration = request.GetExpiration().AsTime()
|
expiration = request.GetExpiration().AsTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey, _, err := api.h.CreateAPIKey(
|
apiKey, _, err := api.h.db.CreateAPIKey(
|
||||||
&expiration,
|
&expiration,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -460,12 +462,12 @@ func (api headscaleV1APIServer) ExpireApiKey(
|
||||||
var apiKey *APIKey
|
var apiKey *APIKey
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
apiKey, err = api.h.GetAPIKey(request.Prefix)
|
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.ExpireAPIKey(apiKey)
|
err = api.h.db.ExpireAPIKey(apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -477,7 +479,7 @@ func (api headscaleV1APIServer) ListApiKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListApiKeysRequest,
|
request *v1.ListApiKeysRequest,
|
||||||
) (*v1.ListApiKeysResponse, error) {
|
) (*v1.ListApiKeysResponse, error) {
|
||||||
apiKeys, err := api.h.ListAPIKeys()
|
apiKeys, err := api.h.db.ListAPIKeys()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -495,12 +497,12 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DebugCreateMachineRequest,
|
request *v1.DebugCreateMachineRequest,
|
||||||
) (*v1.DebugCreateMachineResponse, error) {
|
) (*v1.DebugCreateMachineResponse, error) {
|
||||||
user, err := api.h.GetUser(request.GetUser())
|
user, err := api.h.db.GetUser(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := stringToIPPrefix(request.GetRoutes())
|
routes, err := util.StringToIPPrefix(request.GetRoutes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -517,7 +519,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
Hostname: "DebugTestMachine",
|
Hostname: "DebugTestMachine",
|
||||||
}
|
}
|
||||||
|
|
||||||
givenName, err := api.h.GenerateGivenName(request.GetKey(), request.GetName())
|
givenName, err := api.h.db.GenerateGivenName(request.GetKey(), request.GetName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -542,7 +544,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
}
|
}
|
||||||
|
|
||||||
api.h.registrationCache.Set(
|
api.h.registrationCache.Set(
|
||||||
NodePublicKeyStripPrefix(nodeKey),
|
util.NodePublicKeyStripPrefix(nodeKey),
|
||||||
newMachine,
|
newMachine,
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
|
@ -21,23 +23,23 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ErrMachineNotFound = Error("machine not found")
|
|
||||||
ErrMachineRouteIsNotAvailable = Error("route is not available on machine")
|
|
||||||
ErrMachineAddressesInvalid = Error("failed to parse machine addresses")
|
|
||||||
ErrMachineNotFoundRegistrationCache = Error(
|
|
||||||
"machine not found in registration cache",
|
|
||||||
)
|
|
||||||
ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface")
|
|
||||||
ErrHostnameTooLong = Error("Hostname too long")
|
|
||||||
ErrDifferentRegisteredUser = Error(
|
|
||||||
"machine was previously registered with a different user",
|
|
||||||
)
|
|
||||||
MachineGivenNameHashLength = 8
|
MachineGivenNameHashLength = 8
|
||||||
MachineGivenNameTrimSize = 2
|
MachineGivenNameTrimSize = 2
|
||||||
|
maxHostnameLength = 255
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
maxHostnameLength = 255
|
ErrMachineNotFound = errors.New("machine not found")
|
||||||
|
ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine")
|
||||||
|
ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
|
||||||
|
ErrMachineNotFoundRegistrationCache = errors.New(
|
||||||
|
"machine not found in registration cache",
|
||||||
|
)
|
||||||
|
ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface")
|
||||||
|
ErrHostnameTooLong = errors.New("hostname too long")
|
||||||
|
ErrDifferentRegisteredUser = errors.New(
|
||||||
|
"machine was previously registered with a different user",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Machine is a Headscale client.
|
// Machine is a Headscale client.
|
||||||
|
@ -188,8 +190,10 @@ func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine
|
||||||
|
|
||||||
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
||||||
// related to the application outside of tests.
|
// related to the application outside of tests.
|
||||||
func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines {
|
func (hsdb *HSDatabase) filterMachinesByACL(
|
||||||
return filterMachinesByACL(currentMachine, peers, h.aclRules)
|
aclRules []tailcfg.FilterRule,
|
||||||
|
currentMachine *Machine, peers Machines) Machines {
|
||||||
|
return filterMachinesByACL(currentMachine, peers, aclRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||||
|
@ -213,14 +217,14 @@ func filterMachinesByACL(
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("Finding direct peers")
|
Msg("Finding direct peers")
|
||||||
|
|
||||||
machines := Machines{}
|
machines := Machines{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
|
||||||
machine.NodeKey).Find(&machines).Error; err != nil {
|
machine.NodeKey).Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
|
|
||||||
|
@ -237,23 +241,27 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
func (hsdb *HSDatabase) getPeers(
|
||||||
|
aclPolicy *ACLPolicy,
|
||||||
|
aclRules []tailcfg.FilterRule,
|
||||||
|
machine *Machine,
|
||||||
|
) (Machines, error) {
|
||||||
var peers Machines
|
var peers Machines
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// If ACLs rules are defined, filter visible host list with the ACLs
|
// If ACLs rules are defined, filter visible host list with the ACLs
|
||||||
// else use the classic user scope
|
// else use the classic user scope
|
||||||
if h.aclPolicy != nil {
|
if aclPolicy != nil {
|
||||||
var machines []Machine
|
var machines []Machine
|
||||||
machines, err = h.ListMachines()
|
machines, err = hsdb.ListMachines()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error retrieving list of machines")
|
log.Error().Err(err).Msg("Error retrieving list of machines")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
peers = h.filterMachinesByACL(machine, machines)
|
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
|
||||||
} else {
|
} else {
|
||||||
peers, err = h.ListPeers(machine)
|
peers, err = hsdb.ListPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -275,10 +283,14 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
|
func (hsdb *HSDatabase) getValidPeers(
|
||||||
|
aclPolicy *ACLPolicy,
|
||||||
|
aclRules []tailcfg.FilterRule,
|
||||||
|
machine *Machine,
|
||||||
|
) (Machines, error) {
|
||||||
validPeers := make(Machines, 0)
|
validPeers := make(Machines, 0)
|
||||||
|
|
||||||
peers, err := h.getPeers(machine)
|
peers, err := hsdb.getPeers(aclPolicy, aclRules, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
@ -292,18 +304,18 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
|
||||||
return validPeers, nil
|
return validPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ListMachines() ([]Machine, error) {
|
func (hsdb *HSDatabase) ListMachines() ([]Machine, error) {
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) {
|
func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) {
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,8 +323,8 @@ func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachine finds a Machine by name and user and returns the Machine struct.
|
// GetMachine finds a Machine by name and user and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
|
func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) {
|
||||||
machines, err := h.ListMachinesByUser(user)
|
machines, err := hsdb.ListMachinesByUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -327,8 +339,8 @@ func (h *Headscale) GetMachine(user string, name string) (*Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct.
|
// GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
|
func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) {
|
||||||
machines, err := h.ListMachinesByUser(user)
|
machines, err := hsdb.ListMachinesByUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -343,9 +355,9 @@ func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machi
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -353,11 +365,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
|
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByMachineKey(
|
func (hsdb *HSDatabase) GetMachineByMachineKey(
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,12 +377,12 @@ func (h *Headscale) GetMachineByMachineKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByNodeKey finds a Machine by its current NodeKey.
|
// GetMachineByNodeKey finds a Machine by its current NodeKey.
|
||||||
func (h *Headscale) GetMachineByNodeKey(
|
func (hsdb *HSDatabase) GetMachineByNodeKey(
|
||||||
nodeKey key.NodePublic,
|
nodeKey key.NodePublic,
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
machine := Machine{}
|
machine := Machine{}
|
||||||
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
|
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
|
||||||
NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -378,14 +390,14 @@ func (h *Headscale) GetMachineByNodeKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
|
// GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByAnyKey(
|
func (hsdb *HSDatabase) GetMachineByAnyKey(
|
||||||
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
machine := Machine{}
|
machine := Machine{}
|
||||||
if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
|
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
|
||||||
MachinePublicKeyStripPrefix(machineKey),
|
util.MachinePublicKeyStripPrefix(machineKey),
|
||||||
NodePublicKeyStripPrefix(nodeKey),
|
util.NodePublicKeyStripPrefix(nodeKey),
|
||||||
NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -394,8 +406,8 @@ func (h *Headscale) GetMachineByAnyKey(
|
||||||
|
|
||||||
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
||||||
// and updates it with the latest data from the database.
|
// and updates it with the latest data from the database.
|
||||||
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error {
|
||||||
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
if result := hsdb.db.Find(machine).First(&machine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -403,20 +415,28 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTags takes a Machine struct pointer and update the forced tags.
|
// SetTags takes a Machine struct pointer and update the forced tags.
|
||||||
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
func (hsdb *HSDatabase) SetTags(
|
||||||
|
machine *Machine,
|
||||||
|
tags []string,
|
||||||
|
// TODO(kradalby): This is a temporary measure to be able to detach the
|
||||||
|
// database completely from the global h. In the future, as part of this
|
||||||
|
// reorg, the rules will be generated on a per node basis, and not be prone
|
||||||
|
// to throwing error at save.
|
||||||
|
updateACL func() error) error {
|
||||||
newTags := []string{}
|
newTags := []string{}
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if !contains(newTags, tag) {
|
if !util.StringOrPrefixListContains(newTags, tag) {
|
||||||
newTags = append(newTags, tag)
|
newTags = append(newTags, tag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
machine.ForcedTags = newTags
|
machine.ForcedTags = newTags
|
||||||
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.setLastStateChangeToNow()
|
|
||||||
|
|
||||||
if err := h.db.Save(machine).Error; err != nil {
|
hsdb.notifyStateChange()
|
||||||
|
|
||||||
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -424,13 +444,13 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpireMachine takes a Machine struct and sets the expire field to now.
|
// ExpireMachine takes a Machine struct and sets the expire field to now.
|
||||||
func (h *Headscale) ExpireMachine(machine *Machine) error {
|
func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine.Expiry = &now
|
machine.Expiry = &now
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
hsdb.notifyStateChange()
|
||||||
|
|
||||||
if err := h.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -439,7 +459,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
|
||||||
|
|
||||||
// RenameMachine takes a Machine struct and a new GivenName for the machines
|
// RenameMachine takes a Machine struct and a new GivenName for the machines
|
||||||
// and renames it.
|
// and renames it.
|
||||||
func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error {
|
||||||
err := CheckForFQDNRules(
|
err := CheckForFQDNRules(
|
||||||
newName,
|
newName,
|
||||||
)
|
)
|
||||||
|
@ -455,9 +475,9 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
||||||
}
|
}
|
||||||
machine.GivenName = newName
|
machine.GivenName = newName
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
hsdb.notifyStateChange()
|
||||||
|
|
||||||
if err := h.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,15 +485,15 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshMachine takes a Machine struct and sets the expire field to now.
|
// RefreshMachine takes a Machine struct and sets the expire field to now.
|
||||||
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
|
func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
machine.Expiry = &expiry
|
machine.Expiry = &expiry
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
hsdb.notifyStateChange()
|
||||||
|
|
||||||
if err := h.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to refresh machine (update expiration) in the database: %w",
|
"failed to refresh machine (update expiration) in the database: %w",
|
||||||
err,
|
err,
|
||||||
|
@ -484,21 +504,21 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMachine softs deletes a Machine from the database.
|
// DeleteMachine softs deletes a Machine from the database.
|
||||||
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error {
|
||||||
err := h.DeleteMachineRoutes(machine)
|
err := hsdb.DeleteMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Delete(&machine).Error; err != nil {
|
if err := hsdb.db.Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) TouchMachine(machine *Machine) error {
|
func (hsdb *HSDatabase) TouchMachine(machine *Machine) error {
|
||||||
return h.db.Updates(Machine{
|
return hsdb.db.Updates(Machine{
|
||||||
ID: machine.ID,
|
ID: machine.ID,
|
||||||
LastSeen: machine.LastSeen,
|
LastSeen: machine.LastSeen,
|
||||||
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
|
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
|
||||||
|
@ -506,13 +526,13 @@ func (h *Headscale) TouchMachine(machine *Machine) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// HardDeleteMachine hard deletes a Machine from the database.
|
// HardDeleteMachine hard deletes a Machine from the database.
|
||||||
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error {
|
||||||
err := h.DeleteMachineRoutes(machine)
|
err := hsdb.DeleteMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -524,8 +544,8 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
|
||||||
return tailcfg.Hostinfo(machine.HostInfo)
|
return tailcfg.Hostinfo(machine.HostInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) isOutdated(machine *Machine) bool {
|
func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool {
|
||||||
if err := h.UpdateMachineFromDatabase(machine); err != nil {
|
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
||||||
// It does not seem meaningful to propagate this error as the end result
|
// It does not seem meaningful to propagate this error as the end result
|
||||||
// will have to be that the machine has to be considered outdated.
|
// will have to be that the machine has to be considered outdated.
|
||||||
return true
|
return true
|
||||||
|
@ -536,7 +556,6 @@ func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
||||||
// This would mostly be for a bit of performance, and can be calculated based on
|
// This would mostly be for a bit of performance, and can be calculated based on
|
||||||
// ACLs.
|
// ACLs.
|
||||||
lastChange := h.getLastStateChange()
|
|
||||||
lastUpdate := machine.CreatedAt
|
lastUpdate := machine.CreatedAt
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
lastUpdate = *machine.LastSuccessfulUpdate
|
||||||
|
@ -576,15 +595,16 @@ func (machines MachinesP) String() string {
|
||||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) toNodes(
|
func (hsdb *HSDatabase) toNodes(
|
||||||
machines Machines,
|
machines Machines,
|
||||||
|
aclPolicy *ACLPolicy,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
nodes := make([]*tailcfg.Node, len(machines))
|
nodes := make([]*tailcfg.Node, len(machines))
|
||||||
|
|
||||||
for index, machine := range machines {
|
for index, machine := range machines {
|
||||||
node, err := h.toNode(machine, baseDomain, dnsConfig)
|
node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -597,13 +617,14 @@ func (h *Headscale) toNodes(
|
||||||
|
|
||||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||||
// as per the expected behaviour in the official SaaS.
|
// as per the expected behaviour in the official SaaS.
|
||||||
func (h *Headscale) toNode(
|
func (hsdb *HSDatabase) toNode(
|
||||||
machine Machine,
|
machine Machine,
|
||||||
|
aclPolicy *ACLPolicy,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey)))
|
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -617,7 +638,7 @@ func (h *Headscale) toNode(
|
||||||
// MachineKey is only used in the legacy protocol
|
// MachineKey is only used in the legacy protocol
|
||||||
if machine.MachineKey != "" {
|
if machine.MachineKey != "" {
|
||||||
err = machineKey.UnmarshalText(
|
err = machineKey.UnmarshalText(
|
||||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||||
|
@ -627,7 +648,7 @@ func (h *Headscale) toNode(
|
||||||
var discoKey key.DiscoPublic
|
var discoKey key.DiscoPublic
|
||||||
if machine.DiscoKey != "" {
|
if machine.DiscoKey != "" {
|
||||||
err := discoKey.UnmarshalText(
|
err := discoKey.UnmarshalText(
|
||||||
[]byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
|
[]byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||||
|
@ -646,13 +667,13 @@ func (h *Headscale) toNode(
|
||||||
[]netip.Prefix{},
|
[]netip.Prefix{},
|
||||||
addrs...) // we append the node own IP, as it is required by the clients
|
addrs...) // we append the node own IP, as it is required by the clients
|
||||||
|
|
||||||
primaryRoutes, err := h.getMachinePrimaryRoutes(&machine)
|
primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
primaryPrefixes := Routes(primaryRoutes).toPrefixes()
|
primaryPrefixes := Routes(primaryRoutes).toPrefixes()
|
||||||
|
|
||||||
machineRoutes, err := h.GetMachineRoutes(&machine)
|
machineRoutes, err := hsdb.GetMachineRoutes(&machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -699,13 +720,13 @@ func (h *Headscale) toNode(
|
||||||
|
|
||||||
online := machine.isOnline()
|
online := machine.isOnline()
|
||||||
|
|
||||||
tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain)
|
tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain)
|
||||||
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
||||||
|
|
||||||
node := tailcfg.Node{
|
node := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(
|
StableID: tailcfg.StableNodeID(
|
||||||
strconv.FormatUint(machine.ID, Base10),
|
strconv.FormatUint(machine.ID, util.Base10),
|
||||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
|
|
||||||
|
@ -827,7 +848,8 @@ func getTags(
|
||||||
return validTags, invalidTags
|
return validTags, invalidTags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) RegisterMachineFromAuthCallback(
|
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||||
|
cache *cache.Cache,
|
||||||
nodeKeyStr string,
|
nodeKeyStr string,
|
||||||
userName string,
|
userName string,
|
||||||
machineExpiry *time.Time,
|
machineExpiry *time.Time,
|
||||||
|
@ -846,9 +868,9 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||||
Str("expiresAt", fmt.Sprintf("%v", machineExpiry)).
|
Str("expiresAt", fmt.Sprintf("%v", machineExpiry)).
|
||||||
Msg("Registering machine from API/CLI or auth callback")
|
Msg("Registering machine from API/CLI or auth callback")
|
||||||
|
|
||||||
if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok {
|
if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
|
||||||
if registrationMachine, ok := machineInterface.(Machine); ok {
|
if registrationMachine, ok := machineInterface.(Machine); ok {
|
||||||
user, err := h.GetUser(userName)
|
user, err := hsdb.GetUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to find user in register machine from auth callback, %w",
|
"failed to find user in register machine from auth callback, %w",
|
||||||
|
@ -869,12 +891,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||||
registrationMachine.Expiry = machineExpiry
|
registrationMachine.Expiry = machineExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err := h.RegisterMachine(
|
machine, err := hsdb.RegisterMachine(
|
||||||
registrationMachine,
|
registrationMachine,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
h.registrationCache.Delete(nodeKeyStr)
|
cache.Delete(nodeKeyStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return machine, err
|
return machine, err
|
||||||
|
@ -887,7 +909,7 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||||
func (h *Headscale) RegisterMachine(machine Machine,
|
func (hsdb *HSDatabase) RegisterMachine(machine Machine,
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
|
@ -900,7 +922,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
||||||
// so we store the machine.Expire and machine.Nodekey that has been set when
|
// so we store the machine.Expire and machine.Nodekey that has been set when
|
||||||
// adding it to the registrationCache
|
// adding it to the registrationCache
|
||||||
if len(machine.IPAddresses) > 0 {
|
if len(machine.IPAddresses) > 0 {
|
||||||
if err := h.db.Save(&machine).Error; err != nil {
|
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||||
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
|
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -915,10 +937,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
||||||
return &machine, nil
|
return &machine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
h.ipAllocationMutex.Lock()
|
hsdb.ipAllocationMutex.Lock()
|
||||||
defer h.ipAllocationMutex.Unlock()
|
defer hsdb.ipAllocationMutex.Unlock()
|
||||||
|
|
||||||
ips, err := h.getAvailableIPs()
|
ips, err := hsdb.getAvailableIPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -931,7 +953,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
||||||
|
|
||||||
machine.IPAddresses = ips
|
machine.IPAddresses = ips
|
||||||
|
|
||||||
if err := h.db.Save(&machine).Error; err != nil {
|
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||||
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
|
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -945,10 +967,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdvertisedRoutes returns the routes that are be advertised by the given machine.
|
// GetAdvertisedRoutes returns the routes that are be advertised by the given machine.
|
||||||
func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
|
func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||||
routes := []Route{}
|
routes := []Route{}
|
||||||
|
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error
|
Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
@ -970,10 +992,10 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEnabledRoutes returns the routes that are enabled for the machine.
|
// GetEnabledRoutes returns the routes that are enabled for the machine.
|
||||||
func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||||
routes := []Route{}
|
routes := []Route{}
|
||||||
|
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true).
|
Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
|
@ -995,13 +1017,13 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) {
|
||||||
return prefixes, nil
|
return prefixes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
||||||
route, err := netip.ParsePrefix(routeStr)
|
route, err := netip.ParsePrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledRoutes, err := h.GetEnabledRoutes(machine)
|
enabledRoutes, err := hsdb.GetEnabledRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Could not get enabled routes")
|
log.Error().Err(err).Msg("Could not get enabled routes")
|
||||||
|
|
||||||
|
@ -1018,7 +1040,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// enableRoutes enables new routes based on a list of new routes.
|
// enableRoutes enables new routes based on a list of new routes.
|
||||||
func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||||
for index, routeStr := range routeStrs {
|
for index, routeStr := range routeStrs {
|
||||||
route, err := netip.ParsePrefix(routeStr)
|
route, err := netip.ParsePrefix(routeStr)
|
||||||
|
@ -1029,13 +1051,13 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
newRoutes[index] = route
|
newRoutes[index] = route
|
||||||
}
|
}
|
||||||
|
|
||||||
advertisedRoutes, err := h.GetAdvertisedRoutes(machine)
|
advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
if !contains(advertisedRoutes, newRoute) {
|
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"route (%s) is not available on node %s: %w",
|
"route (%s) is not available on node %s: %w",
|
||||||
machine.Hostname,
|
machine.Hostname,
|
||||||
|
@ -1047,7 +1069,7 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
// Separate loop so we don't leave things in a half-updated state
|
// Separate loop so we don't leave things in a half-updated state
|
||||||
for _, prefix := range newRoutes {
|
for _, prefix := range newRoutes {
|
||||||
route := Route{}
|
route := Route{}
|
||||||
err := h.db.Preload("Machine").
|
err := hsdb.db.Preload("Machine").
|
||||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
||||||
First(&route).Error
|
First(&route).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -1056,10 +1078,10 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
// Mark already as primary if there is only this node offering this subnet
|
// Mark already as primary if there is only this node offering this subnet
|
||||||
// (and is not an exit route)
|
// (and is not an exit route)
|
||||||
if !route.isExitRoute() {
|
if !route.isExitRoute() {
|
||||||
route.IsPrimary = h.isUniquePrefix(route)
|
route.IsPrimary = hsdb.isUniquePrefix(route)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.db.Save(&route).Error
|
err = hsdb.db.Save(&route).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to enable route: %w", err)
|
return fmt.Errorf("failed to enable route: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1068,19 +1090,19 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
hsdb.notifyStateChange()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
||||||
func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error {
|
||||||
if len(machine.IPAddresses) == 0 {
|
if len(machine.IPAddresses) == 0 {
|
||||||
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := []Route{}
|
routes := []Route{}
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).
|
Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
|
@ -1097,7 +1119,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
||||||
approvedRoutes := []Route{}
|
approvedRoutes := []Route{}
|
||||||
|
|
||||||
for _, advertisedRoute := range routes {
|
for _, advertisedRoute := range routes {
|
||||||
routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers(
|
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||||
netip.Prefix(advertisedRoute.Prefix),
|
netip.Prefix(advertisedRoute.Prefix),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1113,7 +1135,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
||||||
if approvedAlias == machine.User.Name {
|
if approvedAlias == machine.User.Name {
|
||||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
} else {
|
} else {
|
||||||
approvedIps, err := h.aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, h.cfg.OIDC.StripEmaildomain)
|
approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).
|
log.Err(err).
|
||||||
Str("alias", approvedAlias).
|
Str("alias", approvedAlias).
|
||||||
|
@ -1132,7 +1154,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
||||||
|
|
||||||
for i, approvedRoute := range approvedRoutes {
|
for i, approvedRoute := range approvedRoutes {
|
||||||
approvedRoutes[i].Enabled = true
|
approvedRoutes[i].Enabled = true
|
||||||
err = h.db.Save(&approvedRoutes[i]).Error
|
err = hsdb.db.Save(&approvedRoutes[i]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).
|
log.Err(err).
|
||||||
Str("approvedRoute", approvedRoute.String()).
|
Str("approvedRoute", approvedRoute.String()).
|
||||||
|
@ -1146,10 +1168,10 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||||
normalizedHostname, err := NormalizeToFQDNRules(
|
normalizedHostname, err := NormalizeToFQDNRules(
|
||||||
suppliedName,
|
suppliedName,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
hsdb.stripEmailDomain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -1162,7 +1184,7 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
|
||||||
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
||||||
}
|
}
|
||||||
|
|
||||||
suffix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -1173,21 +1195,21 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s
|
||||||
return normalizedHostname, nil
|
return normalizedHostname, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
||||||
givenName, err := h.generateGivenName(suppliedName, false)
|
givenName, err := hsdb.generateGivenName(suppliedName, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||||
machines, err := h.ListMachinesByGivenName(givenName)
|
machines, err := hsdb.ListMachinesByGivenName(givenName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
if machine.MachineKey != machineKey && machine.GivenName == givenName {
|
if machine.MachineKey != machineKey && machine.GivenName == givenName {
|
||||||
postfixedName, err := h.generateGivenName(suppliedName, true)
|
postfixedName, err := hsdb.generateGivenName(suppliedName, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,19 +9,20 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetMachine(c *check.C) {
|
func (s *Suite) TestGetMachine(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := &Machine{
|
machine := &Machine{
|
||||||
|
@ -34,20 +35,20 @@ func (s *Suite) TestGetMachine(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine)
|
app.db.db.Save(machine)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachineByID(0)
|
_, err = app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -60,20 +61,20 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
_, err = app.GetMachineByID(0)
|
_, err = app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachineByID(0)
|
_, err = app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -81,28 +82,28 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testmachine",
|
Hostname: "testmachine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
_, err = app.GetMachineByNodeKey(nodeKey.Public())
|
_, err = app.db.GetMachineByNodeKey(nodeKey.Public())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachineByID(0)
|
_, err = app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -112,22 +113,22 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testmachine",
|
Hostname: "testmachine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
_, err = app.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
_, err = app.db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
|
@ -139,17 +140,17 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(1),
|
AuthKeyID: uint(1),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.DeleteMachine(&machine)
|
err = app.db.DeleteMachine(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine(user.Name, "testmachine")
|
_, err = app.db.GetMachine(user.Name, "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
|
@ -161,23 +162,23 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(1),
|
AuthKeyID: uint(1),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.HardDeleteMachine(&machine)
|
err = app.db.HardDeleteMachine(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine(user.Name, "testmachine3")
|
_, err = app.db.GetMachine(user.Name, "testmachine3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestListPeers(c *check.C) {
|
func (s *Suite) TestListPeers(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachineByID(0)
|
_, err = app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
for index := 0; index <= 10; index++ {
|
for index := 0; index <= 10; index++ {
|
||||||
|
@ -191,13 +192,13 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
machine0ByID, err := app.GetMachineByID(0)
|
machine0ByID, err := app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfMachine0, err := app.ListPeers(machine0ByID)
|
peersOfMachine0, err := app.db.ListPeers(machine0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||||
|
@ -215,14 +216,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
stor := make([]base, 0)
|
stor := make([]base, 0)
|
||||||
|
|
||||||
for _, name := range []string{"test", "admin"} {
|
for _, name := range []string{"test", "admin"} {
|
||||||
user, err := app.CreateUser(name)
|
user, err := app.db.CreateUser(name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
stor = append(stor, base{user, pak})
|
stor = append(stor, base{user, pak})
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := app.GetMachineByID(0)
|
_, err := app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
for index := 0; index <= 10; index++ {
|
for index := 0; index <= 10; index++ {
|
||||||
|
@ -239,7 +240,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(stor[index%2].key.ID),
|
AuthKeyID: uint(stor[index%2].key.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.aclPolicy = &ACLPolicy{
|
app.aclPolicy = &ACLPolicy{
|
||||||
|
@ -266,19 +267,19 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
err = app.UpdateACLRules()
|
err = app.UpdateACLRules()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
adminMachine, err := app.GetMachineByID(1)
|
adminMachine, err := app.db.GetMachineByID(1)
|
||||||
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
|
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testMachine, err := app.GetMachineByID(2)
|
testMachine, err := app.db.GetMachineByID(2)
|
||||||
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machines, err := app.ListMachines()
|
machines, err := app.db.ListMachines()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfTestMachine := app.filterMachinesByACL(testMachine, machines)
|
peersOfTestMachine := app.db.filterMachinesByACL(app.aclRules, testMachine, machines)
|
||||||
peersOfAdminMachine := app.filterMachinesByACL(adminMachine, machines)
|
peersOfAdminMachine := app.db.filterMachinesByACL(app.aclRules, adminMachine, machines)
|
||||||
|
|
||||||
c.Log(peersOfTestMachine)
|
c.Log(peersOfTestMachine)
|
||||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||||
|
@ -294,13 +295,13 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := &Machine{
|
machine := &Machine{
|
||||||
|
@ -314,15 +315,15 @@ func (s *Suite) TestExpireMachine(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
}
|
}
|
||||||
app.db.Save(machine)
|
app.db.db.Save(machine)
|
||||||
|
|
||||||
machineFromDB, err := app.GetMachine("test", "testmachine")
|
machineFromDB, err := app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(machineFromDB, check.NotNil)
|
c.Assert(machineFromDB, check.NotNil)
|
||||||
|
|
||||||
c.Assert(machineFromDB.isExpired(), check.Equals, false)
|
c.Assert(machineFromDB.isExpired(), check.Equals, false)
|
||||||
|
|
||||||
err = app.ExpireMachine(machineFromDB)
|
err = app.db.ExpireMachine(machineFromDB)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(machineFromDB.isExpired(), check.Equals, true)
|
c.Assert(machineFromDB.isExpired(), check.Equals, true)
|
||||||
|
@ -350,13 +351,13 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||||
user1, err := app.CreateUser("user-1")
|
user1, err := app.db.CreateUser("user-1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("user-1", "testmachine")
|
_, err = app.db.GetMachine("user-1", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := &Machine{
|
machine := &Machine{
|
||||||
|
@ -370,37 +371,37 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine)
|
app.db.db.Save(machine)
|
||||||
|
|
||||||
givenName, err := app.GenerateGivenName("machine-key-2", "hostname-2")
|
givenName, err := app.db.GenerateGivenName("machine-key-2", "hostname-2")
|
||||||
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
|
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
|
||||||
c.Assert(err, check.IsNil, comment)
|
c.Assert(err, check.IsNil, comment)
|
||||||
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||||
|
|
||||||
givenName, err = app.GenerateGivenName("machine-key-1", "hostname-1")
|
givenName, err = app.db.GenerateGivenName("machine-key-1", "hostname-1")
|
||||||
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
|
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
|
||||||
c.Assert(err, check.IsNil, comment)
|
c.Assert(err, check.IsNil, comment)
|
||||||
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||||
|
|
||||||
givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1")
|
givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||||
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
|
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
|
||||||
c.Assert(err, check.IsNil, comment)
|
c.Assert(err, check.IsNil, comment)
|
||||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||||
|
|
||||||
givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1")
|
givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||||
c.Assert(err, check.IsNil, comment)
|
c.Assert(err, check.IsNil, comment)
|
||||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetTags(c *check.C) {
|
func (s *Suite) TestSetTags(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "testmachine")
|
_, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machine := &Machine{
|
machine := &Machine{
|
||||||
|
@ -413,21 +414,21 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine)
|
app.db.db.Save(machine)
|
||||||
|
|
||||||
// assign simple tags
|
// assign simple tags
|
||||||
sTags := []string{"tag:test", "tag:foo"}
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
err = app.SetTags(machine, sTags)
|
err = app.db.SetTags(machine, sTags, app.UpdateACLRules)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
machine, err = app.GetMachine("test", "testmachine")
|
machine, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags))
|
c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags))
|
||||||
|
|
||||||
// assign duplicat tags, expect no errors but no doubles in DB
|
// assign duplicat tags, expect no errors but no doubles in DB
|
||||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
err = app.SetTags(machine, eTags)
|
err = app.db.SetTags(machine, eTags, app.UpdateACLRules)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
machine, err = app.GetMachine("test", "testmachine")
|
machine, err = app.db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
machine.ForcedTags,
|
machine.ForcedTags,
|
||||||
|
@ -562,7 +563,7 @@ func Test_getTags(t *testing.T) {
|
||||||
test.args.stripEmailDomain,
|
test.args.stripEmailDomain,
|
||||||
)
|
)
|
||||||
for _, valid := range gotValid {
|
for _, valid := range gotValid {
|
||||||
if !contains(test.wantValid, valid) {
|
if !util.StringOrPrefixListContains(test.wantValid, valid) {
|
||||||
t.Errorf(
|
t.Errorf(
|
||||||
"valids: getTags() = %v, want %v",
|
"valids: getTags() = %v, want %v",
|
||||||
gotValid,
|
gotValid,
|
||||||
|
@ -573,7 +574,7 @@ func Test_getTags(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, invalid := range gotInvalid {
|
for _, invalid := range gotInvalid {
|
||||||
if !contains(test.wantInvalid, invalid) {
|
if !util.StringOrPrefixListContains(test.wantInvalid, invalid) {
|
||||||
t.Errorf(
|
t.Errorf(
|
||||||
"invalids: getTags() = %v, want %v",
|
"invalids: getTags() = %v, want %v",
|
||||||
gotInvalid,
|
gotInvalid,
|
||||||
|
@ -1061,19 +1062,15 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
h *Headscale
|
db *HSDatabase
|
||||||
args args
|
args args
|
||||||
want *regexp.Regexp
|
want *regexp.Regexp
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple machine name generation",
|
name: "simple machine name generation",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "testmachine",
|
suppliedName: "testmachine",
|
||||||
|
@ -1084,12 +1081,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 53 chars",
|
name: "machine name with 53 chars",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
||||||
|
@ -1100,12 +1093,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 63 chars",
|
name: "machine name with 63 chars",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
|
@ -1116,12 +1105,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 64 chars",
|
name: "machine name with 64 chars",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
||||||
|
@ -1132,12 +1117,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 73 chars",
|
name: "machine name with 73 chars",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
||||||
|
@ -1148,12 +1129,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with random suffix",
|
name: "machine name with random suffix",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "test",
|
suppliedName: "test",
|
||||||
|
@ -1164,12 +1141,8 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 63 chars with random suffix",
|
name: "machine name with 63 chars with random suffix",
|
||||||
h: &Headscale{
|
db: &HSDatabase{
|
||||||
cfg: &Config{
|
stripEmailDomain: true,
|
||||||
OIDC: OIDCConfig{
|
|
||||||
StripEmaildomain: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
|
@ -1181,7 +1154,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.h.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
|
got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf(
|
t.Errorf(
|
||||||
"Headscale.GenerateGivenName() error = %v, wantErr %v",
|
"Headscale.GenerateGivenName() error = %v, wantErr %v",
|
||||||
|
@ -1239,10 +1212,10 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
err := app.LoadACLPolicyFromBytes(acl, "hujson")
|
err := app.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -1255,7 +1228,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test",
|
Hostname: "test",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
|
@ -1268,18 +1241,18 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine)
|
err = app.db.processMachineRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine0ByID, err := app.GetMachineByID(0)
|
machine0ByID, err := app.db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.EnableAutoApprovedRoutes(machine0ByID)
|
err = app.db.EnableAutoApprovedRoutes(app.aclPolicy, machine0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes, err := app.GetEnabledRoutes(machine0ByID)
|
enabledRoutes, err := app.db.GetEnabledRoutes(machine0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(enabledRoutes, check.HasLen, 3)
|
c.Assert(enabledRoutes, check.HasLen, 3)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -21,16 +22,22 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
randomByteSize = 16
|
randomByteSize = 16
|
||||||
|
)
|
||||||
|
|
||||||
errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
|
var (
|
||||||
errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback")
|
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||||
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
|
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
||||||
errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group")
|
errOIDCAllowedDomains = errors.New(
|
||||||
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
|
"authenticated principal does not match any allowed domain",
|
||||||
errOIDCInvalidMachineState = Error(
|
)
|
||||||
|
errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group")
|
||||||
|
errOIDCAllowedUsers = errors.New(
|
||||||
|
"authenticated principal does not match any allowed user",
|
||||||
|
)
|
||||||
|
errOIDCInvalidMachineState = errors.New(
|
||||||
"requested machine state key expired before authorisation completed",
|
"requested machine state key expired before authorisation completed",
|
||||||
)
|
)
|
||||||
errOIDCNodeKeyMissing = Error("could not get node key from cache")
|
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
|
@ -94,7 +101,7 @@ func (h *Headscale) RegisterOIDC(
|
||||||
Bool("ok", ok).
|
Bool("ok", ok).
|
||||||
Msg("Received oidc register call")
|
Msg("Received oidc register call")
|
||||||
|
|
||||||
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
@ -115,7 +122,7 @@ func (h *Headscale) RegisterOIDC(
|
||||||
// the template and log an error.
|
// the template and log an error.
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||||
)
|
)
|
||||||
|
|
||||||
if !ok || nodeKeyStr == "" || err != nil {
|
if !ok || nodeKeyStr == "" || err != nil {
|
||||||
|
@ -149,7 +156,11 @@ func (h *Headscale) RegisterOIDC(
|
||||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the node key into the state cache, so it can be retrieved later
|
// place the node key into the state cache, so it can be retrieved later
|
||||||
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
|
h.registrationCache.Set(
|
||||||
|
stateStr,
|
||||||
|
util.NodePublicKeyStripPrefix(nodeKey),
|
||||||
|
registerCacheExpiration,
|
||||||
|
)
|
||||||
|
|
||||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||||
|
@ -406,7 +417,7 @@ func validateOIDCAllowedDomains(
|
||||||
) error {
|
) error {
|
||||||
if len(allowedDomains) > 0 {
|
if len(allowedDomains) > 0 {
|
||||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||||
!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
|
!util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
|
||||||
log.Error().Msg("authenticated principal does not match any allowed domain")
|
log.Error().Msg("authenticated principal does not match any allowed domain")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
@ -436,7 +447,7 @@ func validateOIDCAllowedGroups(
|
||||||
) error {
|
) error {
|
||||||
if len(allowedGroups) > 0 {
|
if len(allowedGroups) > 0 {
|
||||||
for _, group := range allowedGroups {
|
for _, group := range allowedGroups {
|
||||||
if IsStringInSlice(claims.Groups, group) {
|
if util.IsStringInSlice(claims.Groups, group) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -466,7 +477,7 @@ func validateOIDCAllowedUsers(
|
||||||
claims *IDTokenClaims,
|
claims *IDTokenClaims,
|
||||||
) error {
|
) error {
|
||||||
if len(allowedUsers) > 0 &&
|
if len(allowedUsers) > 0 &&
|
||||||
!IsStringInSlice(allowedUsers, claims.Email) {
|
!util.IsStringInSlice(allowedUsers, claims.Email) {
|
||||||
log.Error().Msg("authenticated principal does not match any allowed user")
|
log.Error().Msg("authenticated principal does not match any allowed user")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
@ -531,7 +542,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
}
|
}
|
||||||
|
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -555,7 +566,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
// exist, then this is a new machine and we will move
|
// exist, then this is a new machine and we will move
|
||||||
// on to registration.
|
// on to registration.
|
||||||
machine, _ := h.GetMachineByNodeKey(nodeKey)
|
machine, _ := h.db.GetMachineByNodeKey(nodeKey)
|
||||||
|
|
||||||
if machine != nil {
|
if machine != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
|
@ -563,7 +574,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("machine already registered, reauthenticating")
|
Msg("machine already registered, reauthenticating")
|
||||||
|
|
||||||
err := h.RefreshMachine(machine, expiry)
|
err := h.db.RefreshMachine(machine, expiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -653,9 +664,9 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
userName string,
|
userName string,
|
||||||
) (*User, error) {
|
) (*User, error) {
|
||||||
user, err := h.GetUser(userName)
|
user, err := h.db.GetUser(userName)
|
||||||
if errors.Is(err, ErrUserNotFound) {
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
user, err = h.CreateUser(userName)
|
user, err = h.db.CreateUser(userName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -702,7 +713,9 @@ func (h *Headscale) registerMachineForOIDCCallback(
|
||||||
nodeKey *key.NodePublic,
|
nodeKey *key.NodePublic,
|
||||||
expiry time.Time,
|
expiry time.Time,
|
||||||
) error {
|
) error {
|
||||||
if _, err := h.RegisterMachineFromAuthCallback(
|
if _, err := h.db.RegisterMachineFromAuthCallback(
|
||||||
|
// TODO(kradalby): find a better way to use the cache across modules
|
||||||
|
h.registrationCache,
|
||||||
nodeKey.String(),
|
nodeKey.String(),
|
||||||
user.Name,
|
user.Name,
|
||||||
&expiry,
|
&expiry,
|
||||||
|
|
|
@ -10,16 +10,17 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
ErrPreAuthKeyNotFound = Error("AuthKey not found")
|
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
|
||||||
ErrPreAuthKeyExpired = Error("AuthKey expired")
|
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
|
||||||
ErrSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
|
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
|
||||||
ErrUserMismatch = Error("user mismatch")
|
ErrUserMismatch = errors.New("user mismatch")
|
||||||
ErrPreAuthKeyACLTagInvalid = Error("AuthKey tag is invalid")
|
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||||
|
@ -45,26 +46,30 @@ type PreAuthKeyACLTag struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func (h *Headscale) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
userName string,
|
userName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*PreAuthKey, error) {
|
) (*PreAuthKey, error) {
|
||||||
user, err := h.GetUser(userName)
|
user, err := hsdb.GetUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tag := range aclTags {
|
for _, tag := range aclTags {
|
||||||
if !strings.HasPrefix(tag, "tag:") {
|
if !strings.HasPrefix(tag, "tag:") {
|
||||||
return nil, fmt.Errorf("%w: '%s' did not begin with 'tag:'", ErrPreAuthKeyACLTagInvalid, tag)
|
return nil, fmt.Errorf(
|
||||||
|
"%w: '%s' did not begin with 'tag:'",
|
||||||
|
ErrPreAuthKeyACLTagInvalid,
|
||||||
|
tag,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
kstr, err := h.generateKey()
|
kstr, err := hsdb.generateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -79,7 +84,7 @@ func (h *Headscale) CreatePreAuthKey(
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.db.Transaction(func(db *gorm.DB) error {
|
err = hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||||
if err := db.Save(&key).Error; err != nil {
|
if err := db.Save(&key).Error; err != nil {
|
||||||
return fmt.Errorf("failed to create key in the database: %w", err)
|
return fmt.Errorf("failed to create key in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -111,14 +116,14 @@ func (h *Headscale) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||||
func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
||||||
user, err := h.GetUser(userName)
|
user, err := hsdb.GetUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := []PreAuthKey{}
|
keys := []PreAuthKey{}
|
||||||
if err := h.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,8 +131,8 @@ func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||||
func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) {
|
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) {
|
||||||
pak, err := h.checkKeyValidity(key)
|
pak, err := hsdb.checkKeyValidity(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -141,8 +146,8 @@ func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error)
|
||||||
|
|
||||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||||
// does not exist.
|
// does not exist.
|
||||||
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
|
func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||||
return h.db.Transaction(func(db *gorm.DB) error {
|
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||||
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil {
|
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
@ -156,8 +161,8 @@ func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||||
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||||
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,9 +170,9 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsePreAuthKey marks a PreAuthKey as used.
|
// UsePreAuthKey marks a PreAuthKey as used.
|
||||||
func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error {
|
func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
|
||||||
k.Used = true
|
k.Used = true
|
||||||
if err := h.db.Save(k).Error; err != nil {
|
if err := hsdb.db.Save(k).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,9 +181,9 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error {
|
||||||
|
|
||||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||||
// If returns no error and a PreAuthKey, it can be used.
|
// If returns no error and a PreAuthKey, it can be used.
|
||||||
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
pak := PreAuthKey{}
|
pak := PreAuthKey{}
|
||||||
if result := h.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -194,7 +199,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +210,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
return &pak, nil
|
return &pak, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) generateKey() (string, error) {
|
func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||||
size := 24
|
size := 24
|
||||||
bytes := make([]byte, size)
|
bytes := make([]byte, size)
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
@ -218,7 +223,7 @@ func (h *Headscale) generateKey() (string, error) {
|
||||||
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
||||||
protoKey := v1.PreAuthKey{
|
protoKey := v1.PreAuthKey{
|
||||||
User: key.User.Name,
|
User: key.User.Name,
|
||||||
Id: strconv.FormatUint(key.ID, Base10),
|
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Ephemeral: key.Ephemeral,
|
Ephemeral: key.Ephemeral,
|
||||||
Reusable: key.Reusable,
|
Reusable: key.Reusable,
|
||||||
|
|
|
@ -7,14 +7,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := app.CreatePreAuthKey("bogus", true, false, nil, nil)
|
_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||||
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -24,10 +24,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
c.Assert(key.User.Name, check.Equals, user.Name)
|
||||||
|
|
||||||
_, err = app.ListPreAuthKeys("bogus")
|
_, err = app.db.ListPreAuthKeys("bogus")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := app.ListPreAuthKeys(user.Name)
|
keys, err := app.db.ListPreAuthKeys(user.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
|
@ -36,41 +36,41 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test2")
|
user, err := app.db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||||
key, err := app.checkKeyValidity("potatoKey")
|
key, err := app.db.checkKeyValidity("potatoKey")
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
user, err := app.CreateUser("test3")
|
user, err := app.db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test4")
|
user, err := app.db.CreateUser("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -83,18 +83,18 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test5")
|
user, err := app.db.CreateUser("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -107,30 +107,30 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test6")
|
user, err := app.db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(key.ID, check.Equals, pak.ID)
|
c.Assert(key.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test7")
|
user, err := app.db.CreateUser("test7")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -145,65 +145,65 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
_, err = app.checkKeyValidity(pak.Key)
|
_, err = app.db.checkKeyValidity(pak.Key)
|
||||||
// Ephemeral keys are by definition reusable
|
// Ephemeral keys are by definition reusable
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test7", "testest")
|
_, err = app.db.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
app.expireEphemeralNodesWorker()
|
app.expireEphemeralNodesWorker()
|
||||||
|
|
||||||
// The machine record should have been deleted
|
// The machine record should have been deleted
|
||||||
_, err = app.GetMachine("test7", "testest")
|
_, err = app.db.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
user, err := app.CreateUser("test3")
|
user, err := app.db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
err = app.ExpirePreAuthKey(pak)
|
err = app.db.ExpirePreAuthKey(pak)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.NotNil)
|
c.Assert(pak.Expiration, check.NotNil)
|
||||||
|
|
||||||
key, err := app.checkKeyValidity(pak.Key)
|
key, err := app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||||
c.Assert(key, check.IsNil)
|
c.Assert(key, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
user, err := app.CreateUser("test6")
|
user, err := app.db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
app.db.Save(&pak)
|
app.db.db.Save(&pak)
|
||||||
|
|
||||||
_, err = app.checkKeyValidity(pak.Key)
|
_, err = app.db.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||||
user, err := app.CreateUser("test8")
|
user, err := app.db.CreateUser("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
tags := []string{"tag:test1", "tag:test2"}
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
_, err = app.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
listedPaks, err := app.ListPreAuthKeys("test8")
|
listedPaks, err := app.db.ListPreAuthKeys("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -82,7 +83,7 @@ func (h *Headscale) KeyHandler(
|
||||||
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
// Old clients don't send a 'v' parameter, so we send the legacy public key
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
_, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public())))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -102,7 +103,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
machine, err := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
||||||
if registerRequest.Auth.AuthKey != "" {
|
if registerRequest.Auth.AuthKey != "" {
|
||||||
|
@ -120,7 +121,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// is that the client will hammer headscale with requests until it gets a
|
// is that the client will hammer headscale with requests until it gets a
|
||||||
// successful RegisterResponse.
|
// successful RegisterResponse.
|
||||||
if registerRequest.Followup != "" {
|
if registerRequest.Followup != "" {
|
||||||
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
|
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||||
|
@ -152,7 +153,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
Msg("New machine not yet in the database")
|
Msg("New machine not yet in the database")
|
||||||
|
|
||||||
givenName, err := h.GenerateGivenName(
|
givenName, err := h.db.GenerateGivenName(
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
registerRequest.Hostinfo.Hostname,
|
registerRequest.Hostinfo.Hostname,
|
||||||
)
|
)
|
||||||
|
@ -171,10 +172,10 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// We create the machine and then keep it around until a callback
|
// We create the machine and then keep it around until a callback
|
||||||
// happens
|
// happens
|
||||||
newMachine := Machine{
|
newMachine := Machine{
|
||||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
}
|
}
|
||||||
|
@ -210,11 +211,11 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
|
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
|
||||||
var storedMachineKey key.MachinePublic
|
var storedMachineKey key.MachinePublic
|
||||||
err = storedMachineKey.UnmarshalText(
|
err = storedMachineKey.UnmarshalText(
|
||||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||||
)
|
)
|
||||||
if err != nil || storedMachineKey.IsZero() {
|
if err != nil || storedMachineKey.IsZero() {
|
||||||
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey)
|
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
|
||||||
if err := h.db.Save(&machine).Error; err != nil {
|
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("func", "RegistrationHandler").
|
Str("func", "RegistrationHandler").
|
||||||
|
@ -231,7 +232,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// - Trying to log out (sending a expiry in the past)
|
// - Trying to log out (sending a expiry in the past)
|
||||||
// - A valid, registered machine, looking for /map
|
// - A valid, registered machine, looking for /map
|
||||||
// - Expired machine wanting to reauthenticate
|
// - Expired machine wanting to reauthenticate
|
||||||
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
|
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) {
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !registerRequest.Expiry.IsZero() &&
|
if !registerRequest.Expiry.IsZero() &&
|
||||||
|
@ -251,7 +252,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
||||||
!machine.isExpired() {
|
!machine.isExpired() {
|
||||||
h.handleMachineRefreshKeyCommon(
|
h.handleMachineRefreshKeyCommon(
|
||||||
writer,
|
writer,
|
||||||
|
@ -282,9 +283,9 @@ func (h *Headscale) handleRegisterCommon(
|
||||||
// we need to make sure the NodeKey matches the one in the request
|
// we need to make sure the NodeKey matches the one in the request
|
||||||
// TODO(juan): What happens when using fast user switching between two
|
// TODO(juan): What happens when using fast user switching between two
|
||||||
// headscale-managed tailnets?
|
// headscale-managed tailnets?
|
||||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||||
*machine,
|
*machine,
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
@ -311,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -372,13 +373,13 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||||
|
|
||||||
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
|
|
||||||
// retrieve machine information if it exist
|
// retrieve machine information if it exist
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
// exist, then this is a new machine and we will move
|
// exist, then this is a new machine and we will move
|
||||||
// on to registration.
|
// on to registration.
|
||||||
machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
machine, _ := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
||||||
if machine != nil {
|
if machine != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -388,7 +389,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
|
|
||||||
machine.NodeKey = nodeKey
|
machine.NodeKey = nodeKey
|
||||||
machine.AuthKeyID = uint(pak.ID)
|
machine.AuthKeyID = uint(pak.ID)
|
||||||
err := h.RefreshMachine(machine, registerRequest.Expiry)
|
err := h.db.RefreshMachine(machine, registerRequest.Expiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -403,7 +404,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
aclTags := pak.toProto().AclTags
|
aclTags := pak.toProto().AclTags
|
||||||
if len(aclTags) > 0 {
|
if len(aclTags) > 0 {
|
||||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||||
err = h.SetTags(machine, aclTags)
|
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -420,7 +421,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
} else {
|
} else {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
|
givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -436,7 +437,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
UserID: pak.User.ID,
|
UserID: pak.User.ID,
|
||||||
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
Expiry: ®isterRequest.Expiry,
|
Expiry: ®isterRequest.Expiry,
|
||||||
NodeKey: nodeKey,
|
NodeKey: nodeKey,
|
||||||
|
@ -445,7 +446,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
ForcedTags: pak.toProto().AclTags,
|
ForcedTags: pak.toProto().AclTags,
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err = h.RegisterMachine(
|
machine, err = h.db.RegisterMachine(
|
||||||
machineToRegister,
|
machineToRegister,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -462,7 +463,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.UsePreAuthKey(pak)
|
err = h.db.UsePreAuthKey(pak)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -591,7 +592,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("Client requested logout")
|
Msg("Client requested logout")
|
||||||
|
|
||||||
err := h.ExpireMachine(&machine)
|
err := h.db.ExpireMachine(&machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -634,7 +635,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
if machine.isEphemeral() {
|
if machine.isEphemeral() {
|
||||||
err = h.HardDeleteMachine(&machine)
|
err = h.db.HardDeleteMachine(&machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -720,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||||
Bool("noise", isNoise).
|
Bool("noise", isNoise).
|
||||||
Str("machine", machine.Hostname).
|
Str("machine", machine.Hostname).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
|
|
||||||
if err := h.db.Save(&machine).Error; err != nil {
|
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -29,10 +30,10 @@ func (h *Headscale) handlePollCommon(
|
||||||
) {
|
) {
|
||||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
|
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
|
||||||
machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
err := h.processMachineRoutes(machine)
|
err := h.db.processMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -53,7 +54,7 @@ func (h *Headscale) handlePollCommon(
|
||||||
}
|
}
|
||||||
|
|
||||||
// update routes with peer information
|
// update routes with peer information
|
||||||
err = h.EnableAutoApprovedRoutes(machine)
|
err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -77,7 +78,7 @@ func (h *Headscale) handlePollCommon(
|
||||||
machine.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Updates(machine).Error; err != nil {
|
if err := h.db.db.Updates(machine).Error; err != nil {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -325,7 +326,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachineFromDatabase(machine)
|
err = h.db.UpdateMachineFromDatabase(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -346,7 +347,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
Set(float64(now.Unix()))
|
Set(float64(now.Unix()))
|
||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
|
|
||||||
err = h.TouchMachine(machine)
|
err = h.db.TouchMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -409,7 +410,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachineFromDatabase(machine)
|
err = h.db.UpdateMachineFromDatabase(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -425,7 +426,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
machine.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
err = h.TouchMachine(machine)
|
err = h.db.TouchMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -456,7 +457,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
if h.isOutdated(machine) {
|
if h.db.isOutdated(machine, h.getLastStateChange()) {
|
||||||
var lastUpdate time.Time
|
var lastUpdate time.Time
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
lastUpdate = *machine.LastSuccessfulUpdate
|
||||||
|
@ -524,7 +525,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err = h.UpdateMachineFromDatabase(machine)
|
err = h.db.UpdateMachineFromDatabase(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -544,7 +545,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
Set(float64(now.Unix()))
|
Set(float64(now.Unix()))
|
||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
|
|
||||||
err = h.TouchMachine(machine)
|
err = h.db.TouchMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -578,7 +579,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
// TODO: Abstract away all the database calls, this can cause race conditions
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
// command line, but then overwritten.
|
// command line, but then overwritten.
|
||||||
err := h.UpdateMachineFromDatabase(machine)
|
err := h.db.UpdateMachineFromDatabase(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -594,7 +595,7 @@ func (h *Headscale) pollNetMapStream(
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
machine.LastSeen = &now
|
machine.LastSeen = &now
|
||||||
err = h.TouchMachine(machine)
|
err = h.db.TouchMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/smallzstd"
|
"tailscale.com/smallzstd"
|
||||||
|
@ -27,7 +28,7 @@ func (h *Headscale) getMapResponseData(
|
||||||
}
|
}
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -50,11 +51,16 @@ func (h *Headscale) getMapKeepAliveResponseData(
|
||||||
}
|
}
|
||||||
|
|
||||||
if isNoise {
|
if isNoise {
|
||||||
return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
|
return h.marshalMapResponse(
|
||||||
|
keepAliveResponse,
|
||||||
|
key.MachinePublic{},
|
||||||
|
mapRequest.Compress,
|
||||||
|
isNoise,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -104,7 +110,7 @@ func (h *Headscale) marshalMapResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
var respBody []byte
|
var respBody []byte
|
||||||
if compression == ZstdCompression {
|
if compression == util.ZstdCompression {
|
||||||
respBody = zstdEncode(jsonBody)
|
respBody = zstdEncode(jsonBody)
|
||||||
if !isNoise { // if legacy protocol
|
if !isNoise { // if legacy protocol
|
||||||
respBody = h.privateKey.SealTo(machineKey, respBody)
|
respBody = h.privateKey.SealTo(machineKey, respBody)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -32,7 +33,7 @@ func (h *Headscale) RegistrationHandler(
|
||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -44,7 +45,7 @@ func (h *Headscale) RegistrationHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
registerRequest := tailcfg.RegisterRequest{}
|
registerRequest := tailcfg.RegisterRequest{}
|
||||||
err = decode(body, ®isterRequest, &machineKey, h.privateKey)
|
err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -44,7 +45,7 @@ func (h *Headscale) PollNetMapHandler(
|
||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -56,7 +57,7 @@ func (h *Headscale) PollNetMapHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mapRequest := tailcfg.MapRequest{}
|
mapRequest := tailcfg.MapRequest{}
|
||||||
err = decode(body, &mapRequest, &machineKey, h.privateKey)
|
err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -67,7 +68,7 @@ func (h *Headscale) PollNetMapHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err := h.GetMachineByMachineKey(machineKey)
|
machine, err := h.db.GetMachineByMachineKey(machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
|
|
|
@ -48,7 +48,11 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||||
|
|
||||||
ns.nodeKey = mapRequest.NodeKey
|
ns.nodeKey = mapRequest.NodeKey
|
||||||
|
|
||||||
machine, err := ns.headscale.GetMachineByAnyKey(ns.conn.Peer(), mapRequest.NodeKey, key.NodePublic{})
|
machine, err := ns.headscale.db.GetMachineByAnyKey(
|
||||||
|
ns.conn.Peer(),
|
||||||
|
mapRequest.NodeKey,
|
||||||
|
key.NodePublic{},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
|
|
|
@ -11,11 +11,8 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ErrRouteIsNotAvailable = Error("route is not available")
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||||
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
||||||
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
||||||
)
|
)
|
||||||
|
@ -51,9 +48,9 @@ func (rs Routes) toPrefixes() []netip.Prefix {
|
||||||
return prefixes
|
return prefixes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) GetRoutes() ([]Route, error) {
|
func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
|
||||||
var routes []Route
|
var routes []Route
|
||||||
err := h.db.Preload("Machine").Find(&routes).Error
|
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -61,9 +58,9 @@ func (h *Headscale) GetRoutes() ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) {
|
func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
|
||||||
var routes []Route
|
var routes []Route
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ?", m.ID).
|
Where("machine_id = ?", m.ID).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
|
@ -74,9 +71,9 @@ func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) GetRoute(id uint64) (*Route, error) {
|
func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) {
|
||||||
var route Route
|
var route Route
|
||||||
err := h.db.Preload("Machine").First(&route, id).Error
|
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -84,8 +81,8 @@ func (h *Headscale) GetRoute(id uint64) (*Route, error) {
|
||||||
return &route, nil
|
return &route, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) EnableRoute(id uint64) error {
|
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||||
route, err := h.GetRoute(id)
|
route, err := hsdb.GetRoute(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -94,14 +91,14 @@ func (h *Headscale) EnableRoute(id uint64) error {
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
if route.isExitRoute() {
|
if route.isExitRoute() {
|
||||||
return h.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
|
return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) DisableRoute(id uint64) error {
|
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||||
route, err := h.GetRoute(id)
|
route, err := hsdb.GetRoute(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -112,15 +109,15 @@ func (h *Headscale) DisableRoute(id uint64) error {
|
||||||
if !route.isExitRoute() {
|
if !route.isExitRoute() {
|
||||||
route.Enabled = false
|
route.Enabled = false
|
||||||
route.IsPrimary = false
|
route.IsPrimary = false
|
||||||
err = h.db.Save(route).Error
|
err = hsdb.db.Save(route).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.handlePrimarySubnetFailover()
|
return hsdb.handlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.GetMachineRoutes(&route.Machine)
|
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -129,18 +126,18 @@ func (h *Headscale) DisableRoute(id uint64) error {
|
||||||
if routes[i].isExitRoute() {
|
if routes[i].isExitRoute() {
|
||||||
routes[i].Enabled = false
|
routes[i].Enabled = false
|
||||||
routes[i].IsPrimary = false
|
routes[i].IsPrimary = false
|
||||||
err = h.db.Save(&routes[i]).Error
|
err = hsdb.db.Save(&routes[i]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.handlePrimarySubnetFailover()
|
return hsdb.handlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) DeleteRoute(id uint64) error {
|
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||||
route, err := h.GetRoute(id)
|
route, err := hsdb.GetRoute(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -149,14 +146,14 @@ func (h *Headscale) DeleteRoute(id uint64) error {
|
||||||
// be enabled at the same time, as per
|
// be enabled at the same time, as per
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
if !route.isExitRoute() {
|
if !route.isExitRoute() {
|
||||||
if err := h.db.Unscoped().Delete(&route).Error; err != nil {
|
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.handlePrimarySubnetFailover()
|
return hsdb.handlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.GetMachineRoutes(&route.Machine)
|
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -168,32 +165,32 @@ func (h *Headscale) DeleteRoute(id uint64) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.handlePrimarySubnetFailover()
|
return hsdb.handlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) DeleteMachineRoutes(m *Machine) error {
|
func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
|
||||||
routes, err := h.GetMachineRoutes(m)
|
routes, err := hsdb.GetMachineRoutes(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range routes {
|
for i := range routes {
|
||||||
if err := h.db.Unscoped().Delete(&routes[i]).Error; err != nil {
|
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.handlePrimarySubnetFailover()
|
return hsdb.handlePrimarySubnetFailover()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isUniquePrefix returns if there is another machine providing the same route already.
|
// isUniquePrefix returns if there is another machine providing the same route already.
|
||||||
func (h *Headscale) isUniquePrefix(route Route) bool {
|
func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
|
||||||
var count int64
|
var count int64
|
||||||
h.db.
|
hsdb.db.
|
||||||
Model(&Route{}).
|
Model(&Route{}).
|
||||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||||
route.Prefix,
|
route.Prefix,
|
||||||
|
@ -203,9 +200,9 @@ func (h *Headscale) isUniquePrefix(route Route) bool {
|
||||||
return count == 0
|
return count == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
||||||
var route Route
|
var route Route
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
|
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
|
||||||
First(&route).Error
|
First(&route).Error
|
||||||
|
@ -222,9 +219,9 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
||||||
|
|
||||||
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||||
func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
||||||
var routes []Route
|
var routes []Route
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
|
@ -235,9 +232,9 @@ func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) processMachineRoutes(machine *Machine) error {
|
func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||||
currentRoutes := []Route{}
|
currentRoutes := []Route{}
|
||||||
err := h.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -251,7 +248,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
|
||||||
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
||||||
if !route.Advertised {
|
if !route.Advertised {
|
||||||
currentRoutes[pos].Advertised = true
|
currentRoutes[pos].Advertised = true
|
||||||
err := h.db.Save(¤tRoutes[pos]).Error
|
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -260,7 +257,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
|
||||||
} else if route.Advertised {
|
} else if route.Advertised {
|
||||||
currentRoutes[pos].Advertised = false
|
currentRoutes[pos].Advertised = false
|
||||||
currentRoutes[pos].Enabled = false
|
currentRoutes[pos].Enabled = false
|
||||||
err := h.db.Save(¤tRoutes[pos]).Error
|
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -275,7 +272,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
|
||||||
Advertised: true,
|
Advertised: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
}
|
}
|
||||||
err := h.db.Create(&route).Error
|
err := hsdb.db.Create(&route).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -285,10 +282,10 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handlePrimarySubnetFailover() error {
|
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||||
// first, get all the enabled routes
|
// first, get all the enabled routes
|
||||||
var routes []Route
|
var routes []Route
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("advertised = ? AND enabled = ?", true, true).
|
Where("advertised = ? AND enabled = ?", true, true).
|
||||||
Find(&routes).Error
|
Find(&routes).Error
|
||||||
|
@ -303,14 +300,14 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !route.IsPrimary {
|
if !route.IsPrimary {
|
||||||
_, err := h.getPrimaryRoute(netip.Prefix(route.Prefix))
|
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
|
||||||
if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("prefix", netip.Prefix(route.Prefix).String()).
|
Str("prefix", netip.Prefix(route.Prefix).String()).
|
||||||
Str("machine", route.Machine.GivenName).
|
Str("machine", route.Machine.GivenName).
|
||||||
Msg("Setting primary route")
|
Msg("Setting primary route")
|
||||||
routes[pos].IsPrimary = true
|
routes[pos].IsPrimary = true
|
||||||
err := h.db.Save(&routes[pos]).Error
|
err := hsdb.db.Save(&routes[pos]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("error marking route as primary")
|
log.Error().Err(err).Msg("error marking route as primary")
|
||||||
|
|
||||||
|
@ -336,7 +333,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
|
||||||
|
|
||||||
// find a new primary route
|
// find a new primary route
|
||||||
var newPrimaryRoutes []Route
|
var newPrimaryRoutes []Route
|
||||||
err := h.db.
|
err := hsdb.db.
|
||||||
Preload("Machine").
|
Preload("Machine").
|
||||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||||
route.Prefix,
|
route.Prefix,
|
||||||
|
@ -375,7 +372,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
|
||||||
|
|
||||||
// disable the old primary route
|
// disable the old primary route
|
||||||
routes[pos].IsPrimary = false
|
routes[pos].IsPrimary = false
|
||||||
err = h.db.Save(&routes[pos]).Error
|
err = hsdb.db.Save(&routes[pos]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("error disabling old primary route")
|
log.Error().Err(err).Msg("error disabling old primary route")
|
||||||
|
|
||||||
|
@ -384,7 +381,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
|
||||||
|
|
||||||
// enable the new primary route
|
// enable the new primary route
|
||||||
newPrimaryRoute.IsPrimary = true
|
newPrimaryRoute.IsPrimary = true
|
||||||
err = h.db.Save(&newPrimaryRoute).Error
|
err = hsdb.db.Save(&newPrimaryRoute).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("error enabling new primary route")
|
log.Error().Err(err).Msg("error enabling new primary route")
|
||||||
|
|
||||||
|
@ -396,7 +393,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if routesChanged {
|
if routesChanged {
|
||||||
h.setLastStateChangeToNow()
|
hsdb.notifyStateChange()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -4,19 +4,20 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_get_route_machine")
|
_, err = app.db.GetMachine("test", "test_get_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -37,30 +38,30 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine)
|
err = app.db.processMachineRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
advertisedRoutes, err := app.GetAdvertisedRoutes(&machine)
|
advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine, "192.168.0.0/24")
|
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine, "10.0.0.0/24")
|
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -88,54 +89,54 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo),
|
HostInfo: HostInfo(hostInfo),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine)
|
err = app.db.processMachineRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
availableRoutes, err := app.GetAdvertisedRoutes(&machine)
|
availableRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(availableRoutes), check.Equals, 2)
|
c.Assert(len(availableRoutes), check.Equals, 2)
|
||||||
|
|
||||||
noEnabledRoutes, err := app.GetEnabledRoutes(&machine)
|
noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine, "192.168.0.0/24")
|
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine, "10.0.0.0/24")
|
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes, err := app.GetEnabledRoutes(&machine)
|
enabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||||
|
|
||||||
// Adding it twice will just let it pass through
|
// Adding it twice will just let it pass through
|
||||||
err = app.enableRoutes(&machine, "10.0.0.0/24")
|
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enableRoutesAfterDoubleApply, err := app.GetEnabledRoutes(&machine)
|
enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine, "150.0.10.0/25")
|
err = app.db.enableRoutes(&machine, "150.0.10.0/25")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutesWithAdditionalRoute, err := app.GetEnabledRoutes(&machine)
|
enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -162,15 +163,15 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: HostInfo(hostInfo1),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine1)
|
app.db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine1)
|
err = app.db.processMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, route.String())
|
err = app.db.enableRoutes(&machine1, route.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, route2.String())
|
err = app.db.enableRoutes(&machine1, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
hostInfo2 := tailcfg.Hostinfo{
|
hostInfo2 := tailcfg.Hostinfo{
|
||||||
|
@ -187,39 +188,39 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
HostInfo: HostInfo(hostInfo2),
|
HostInfo: HostInfo(hostInfo2),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine2)
|
app.db.db.Save(&machine2)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine2)
|
err = app.db.processMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine2, route2.String())
|
err = app.db.enableRoutes(&machine2, route2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
enabledRoutes2, err := app.GetEnabledRoutes(&machine2)
|
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||||
|
|
||||||
routes, err := app.getMachinePrimaryRoutes(&machine1)
|
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine2)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -249,25 +250,25 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine1)
|
app.db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine1)
|
err = app.db.processMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, prefix.String())
|
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, prefix2.String())
|
err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.handlePrimarySubnetFailover()
|
err = app.db.handlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
route, err := app.getPrimaryRoute(prefix)
|
route, err := app.db.getPrimaryRoute(prefix)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
||||||
|
|
||||||
|
@ -286,70 +287,70 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
HostInfo: HostInfo(hostInfo2),
|
HostInfo: HostInfo(hostInfo2),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine2)
|
app.db.db.Save(&machine2)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine2)
|
err = app.db.processMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine2, prefix2.String())
|
err = app.db.enableRoutes(&machine2, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.handlePrimarySubnetFailover()
|
err = app.db.handlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err = app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||||
|
|
||||||
enabledRoutes2, err := app.GetEnabledRoutes(&machine2)
|
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||||
|
|
||||||
routes, err := app.getMachinePrimaryRoutes(&machine1)
|
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine2)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
// lets make machine1 lastseen 10 mins ago
|
// lets make machine1 lastseen 10 mins ago
|
||||||
before := now.Add(-10 * time.Minute)
|
before := now.Add(-10 * time.Minute)
|
||||||
machine1.LastSeen = &before
|
machine1.LastSeen = &before
|
||||||
err = app.db.Save(&machine1).Error
|
err = app.db.db.Save(&machine1).Error
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.handlePrimarySubnetFailover()
|
err = app.db.handlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine1)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 1)
|
c.Assert(len(routes), check.Equals, 1)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine2)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 1)
|
c.Assert(len(routes), check.Equals, 1)
|
||||||
|
|
||||||
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{
|
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||||
})
|
})
|
||||||
err = app.db.Save(&machine2).Error
|
err = app.db.db.Save(&machine2).Error
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine2)
|
err = app.db.processMachineRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine2, prefix.String())
|
err = app.db.enableRoutes(&machine2, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.handlePrimarySubnetFailover()
|
err = app.db.handlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine1)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
routes, err = app.getMachinePrimaryRoutes(&machine2)
|
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
}
|
}
|
||||||
|
@ -358,13 +359,13 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
// including both the primary routes the node is responsible for, and the
|
// including both the primary routes the node is responsible for, and the
|
||||||
// exit node routes if enabled.
|
// exit node routes if enabled.
|
||||||
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -398,9 +399,9 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine1 := Machine{
|
machine1 := Machine{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||||
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||||
DiscoKey: DiscoPublicKeyStripPrefix(discoKey.Public()),
|
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
|
||||||
Hostname: "test_enable_route_machine",
|
Hostname: "test_enable_route_machine",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
|
@ -408,23 +409,23 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine1)
|
app.db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine1)
|
err = app.db.processMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, prefix.String())
|
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// We do not enable this one on purpose to test that it is not enabled
|
// We do not enable this one on purpose to test that it is not enabled
|
||||||
// err = app.enableRoutes(&machine1, prefix2.String())
|
// err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||||
// c.Assert(err, check.IsNil)
|
// c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err := app.GetMachineRoutes(&machine1)
|
routes, err := app.db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.isExitRoute() {
|
if route.isExitRoute() {
|
||||||
err = app.EnableRoute(uint64(route.ID))
|
err = app.db.EnableRoute(uint64(route.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// We only enable one exit route, so we can test that both are enabled
|
// We only enable one exit route, so we can test that both are enabled
|
||||||
|
@ -432,14 +433,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.handlePrimarySubnetFailover()
|
err = app.db.handlePrimarySubnetFailover()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
||||||
|
|
||||||
peer, err := app.toNode(machine1, "headscale.net", nil)
|
peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
||||||
|
@ -469,35 +470,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.DisableRoute(uint64(exitRouteV4.ID))
|
err = app.db.DisableRoute(uint64(exitRouteV4.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err = app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
|
|
||||||
// and now we delete only one of the exit routes
|
// and now we delete only one of the exit routes
|
||||||
// and we check if both are deleted
|
// and we check if both are deleted
|
||||||
routes, err = app.GetMachineRoutes(&machine1)
|
routes, err = app.db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 4)
|
c.Assert(len(routes), check.Equals, 4)
|
||||||
|
|
||||||
err = app.DeleteRoute(uint64(exitRouteV4.ID))
|
err = app.db.DeleteRoute(uint64(exitRouteV4.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err = app.GetMachineRoutes(&machine1)
|
routes, err = app.db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
@ -527,24 +528,24 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
HostInfo: HostInfo(hostInfo1),
|
HostInfo: HostInfo(hostInfo1),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
app.db.Save(&machine1)
|
app.db.db.Save(&machine1)
|
||||||
|
|
||||||
err = app.processMachineRoutes(&machine1)
|
err = app.db.processMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, prefix.String())
|
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.enableRoutes(&machine1, prefix2.String())
|
err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
routes, err := app.GetMachineRoutes(&machine1)
|
routes, err := app.db.GetMachineRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.DeleteRoute(uint64(routes[0].ID))
|
err = app.db.DeleteRoute(uint64(routes[0].ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,17 +9,18 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
ErrUserExists = Error("User already exists")
|
ErrUserExists = errors.New("user already exists")
|
||||||
ErrUserNotFound = Error("User not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrUserStillHasNodes = Error("User not empty: node(s) found")
|
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||||
ErrInvalidUserName = Error("Invalid user name")
|
ErrInvalidUserName = errors.New("invalid user name")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -40,17 +41,17 @@ type User struct {
|
||||||
|
|
||||||
// CreateUser creates a new User. Returns error if could not be created
|
// CreateUser creates a new User. Returns error if could not be created
|
||||||
// or another user already exists.
|
// or another user already exists.
|
||||||
func (h *Headscale) CreateUser(name string) (*User, error) {
|
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) {
|
||||||
err := CheckForFQDNRules(name)
|
err := CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := User{}
|
user := User{}
|
||||||
if err := h.db.Where("name = ?", name).First(&user).Error; err == nil {
|
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||||
return nil, ErrUserExists
|
return nil, ErrUserExists
|
||||||
}
|
}
|
||||||
user.Name = name
|
user.Name = name
|
||||||
if err := h.db.Create(&user).Error; err != nil {
|
if err := hsdb.db.Create(&user).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "CreateUser").
|
Str("func", "CreateUser").
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -64,13 +65,13 @@ func (h *Headscale) CreateUser(name string) (*User, error) {
|
||||||
|
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are machines associated with it.
|
// not exist or if there are machines associated with it.
|
||||||
func (h *Headscale) DestroyUser(name string) error {
|
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||||
user, err := h.GetUser(name)
|
user, err := hsdb.GetUser(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUserNotFound
|
return ErrUserNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
machines, err := h.ListMachinesByUser(name)
|
machines, err := hsdb.ListMachinesByUser(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -78,18 +79,18 @@ func (h *Headscale) DestroyUser(name string) error {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := h.ListPreAuthKeys(name)
|
keys, err := hsdb.ListPreAuthKeys(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
err = h.DestroyPreAuthKey(key)
|
err = hsdb.DestroyPreAuthKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if result := h.db.Unscoped().Delete(&user); result.Error != nil {
|
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,9 +99,9 @@ func (h *Headscale) DestroyUser(name string) error {
|
||||||
|
|
||||||
// RenameUser renames a User. Returns error if the User does
|
// RenameUser renames a User. Returns error if the User does
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func (h *Headscale) RenameUser(oldName, newName string) error {
|
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := h.GetUser(oldName)
|
oldUser, err := hsdb.GetUser(oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -108,7 +109,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = h.GetUser(newName)
|
_, err = hsdb.GetUser(newName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return ErrUserExists
|
return ErrUserExists
|
||||||
}
|
}
|
||||||
|
@ -118,7 +119,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
||||||
|
|
||||||
oldUser.Name = newName
|
oldUser.Name = newName
|
||||||
|
|
||||||
if result := h.db.Save(&oldUser); result.Error != nil {
|
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,9 +127,9 @@ func (h *Headscale) RenameUser(oldName, newName string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUser fetches a user by name.
|
// GetUser fetches a user by name.
|
||||||
func (h *Headscale) GetUser(name string) (*User, error) {
|
func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
||||||
user := User{}
|
user := User{}
|
||||||
if result := h.db.First(&user, "name = ?", name); errors.Is(
|
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -139,9 +140,9 @@ func (h *Headscale) GetUser(name string) (*User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func (h *Headscale) ListUsers() ([]User, error) {
|
func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
||||||
users := []User{}
|
users := []User{}
|
||||||
if err := h.db.Find(&users).Error; err != nil {
|
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,18 +150,18 @@ func (h *Headscale) ListUsers() ([]User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListMachinesByUser gets all the nodes in a given user.
|
// ListMachinesByUser gets all the nodes in a given user.
|
||||||
func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
|
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||||
err := CheckForFQDNRules(name)
|
err := CheckForFQDNRules(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user, err := h.GetUser(name)
|
user, err := hsdb.GetUser(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,17 +169,17 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMachineUser assigns a Machine to a user.
|
// SetMachineUser assigns a Machine to a user.
|
||||||
func (h *Headscale) SetMachineUser(machine *Machine, username string) error {
|
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error {
|
||||||
err := CheckForFQDNRules(username)
|
err := CheckForFQDNRules(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
user, err := h.GetUser(username)
|
user, err := hsdb.GetUser(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
machine.User = *user
|
machine.User = *user
|
||||||
if result := h.db.Save(&machine); result.Error != nil {
|
if result := hsdb.db.Save(&machine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,7 +212,7 @@ func (n *User) toTailscaleLogin() *tailcfg.Login {
|
||||||
return &login
|
return &login
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapResponseUserProfiles(
|
func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
||||||
machine Machine,
|
machine Machine,
|
||||||
peers Machines,
|
peers Machines,
|
||||||
) []tailcfg.UserProfile {
|
) []tailcfg.UserProfile {
|
||||||
|
@ -225,8 +226,8 @@ func (h *Headscale) getMapResponseUserProfiles(
|
||||||
for _, user := range userMap {
|
for _, user := range userMap {
|
||||||
displayName := user.Name
|
displayName := user.Name
|
||||||
|
|
||||||
if h.cfg.BaseDomain != "" {
|
if hsdb.baseDomain != "" {
|
||||||
displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain)
|
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles = append(profiles,
|
profiles = append(profiles,
|
||||||
|
@ -242,7 +243,7 @@ func (h *Headscale) getMapResponseUserProfiles(
|
||||||
|
|
||||||
func (n *User) toProto() *v1.User {
|
func (n *User) toProto() *v1.User {
|
||||||
return &v1.User{
|
return &v1.User{
|
||||||
Id: strconv.FormatUint(uint64(n.ID), Base10),
|
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
||||||
Name: n.Name,
|
Name: n.Name,
|
||||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,42 +9,42 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(user.Name, check.Equals, "test")
|
c.Assert(user.Name, check.Equals, "test")
|
||||||
|
|
||||||
users, err := app.ListUsers()
|
users, err := app.db.ListUsers()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = app.DestroyUser("test")
|
err = app.db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetUser("test")
|
_, err = app.db.GetUser("test")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
err := app.DestroyUser("test")
|
err := app.db.DestroyUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
user, err := app.CreateUser("test")
|
user, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.DestroyUser("test")
|
err = app.db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := app.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
// destroying a user also deletes all associated preauthkeys
|
// destroying a user also deletes all associated preauthkeys
|
||||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
user, err = app.CreateUser("test")
|
user, err = app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err = app.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -57,52 +57,52 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
|
|
||||||
err = app.DestroyUser("test")
|
err = app.db.DestroyUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestRenameUser(c *check.C) {
|
func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
userTest, err := app.CreateUser("test")
|
userTest, err := app.db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(userTest.Name, check.Equals, "test")
|
c.Assert(userTest.Name, check.Equals, "test")
|
||||||
|
|
||||||
users, err := app.ListUsers()
|
users, err := app.db.ListUsers()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = app.RenameUser("test", "test-renamed")
|
err = app.db.RenameUser("test", "test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetUser("test")
|
_, err = app.db.GetUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
_, err = app.GetUser("test-renamed")
|
_, err = app.db.GetUser("test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = app.RenameUser("test-does-not-exit", "test")
|
err = app.db.RenameUser("test-does-not-exit", "test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
userTest2, err := app.CreateUser("test2")
|
userTest2, err := app.db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
err = app.RenameUser("test2", "test-renamed")
|
err = app.db.RenameUser("test2", "test-renamed")
|
||||||
c.Assert(err, check.Equals, ErrUserExists)
|
c.Assert(err, check.Equals, ErrUserExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
userShared1, err := app.CreateUser("shared1")
|
userShared1, err := app.db.CreateUser("shared1")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared2, err := app.CreateUser("shared2")
|
userShared2, err := app.db.CreateUser("shared2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userShared3, err := app.CreateUser("shared3")
|
userShared3, err := app.db.CreateUser("shared3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyShared1, err := app.CreatePreAuthKey(
|
preAuthKeyShared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -111,7 +111,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyShared2, err := app.CreatePreAuthKey(
|
preAuthKeyShared2, err := app.db.CreatePreAuthKey(
|
||||||
userShared2.Name,
|
userShared2.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -120,7 +120,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKeyShared3, err := app.CreatePreAuthKey(
|
preAuthKeyShared3, err := app.db.CreatePreAuthKey(
|
||||||
userShared3.Name,
|
userShared3.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -129,7 +129,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
preAuthKey2Shared1, err := app.CreatePreAuthKey(
|
preAuthKey2Shared1, err := app.db.CreatePreAuthKey(
|
||||||
userShared1.Name,
|
userShared1.Name,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -138,7 +138,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
)
|
)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
machineInShared1 := &Machine{
|
machineInShared1 := &Machine{
|
||||||
|
@ -153,9 +153,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared1)
|
app.db.db.Save(machineInShared1)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared2 := &Machine{
|
machineInShared2 := &Machine{
|
||||||
|
@ -170,9 +170,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared2)
|
app.db.db.Save(machineInShared2)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machineInShared3 := &Machine{
|
machineInShared3 := &Machine{
|
||||||
|
@ -187,9 +187,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machineInShared3)
|
app.db.db.Save(machineInShared3)
|
||||||
|
|
||||||
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine2InShared1 := &Machine{
|
machine2InShared1 := &Machine{
|
||||||
|
@ -204,12 +204,12 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(machine2InShared1)
|
app.db.db.Save(machine2InShared1)
|
||||||
|
|
||||||
peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
|
peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
userProfiles := app.getMapResponseUserProfiles(
|
userProfiles := app.db.getMapResponseUserProfiles(
|
||||||
*machineInShared1,
|
*machineInShared1,
|
||||||
peersOfMachine1InShared1,
|
peersOfMachine1InShared1,
|
||||||
)
|
)
|
||||||
|
@ -378,13 +378,13 @@ func TestCheckForFQDNRules(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
oldUser, err := app.CreateUser("old")
|
oldUser, err := app.db.CreateUser("old")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
newUser, err := app.CreateUser("new")
|
newUser, err := app.db.CreateUser("new")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := app.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
machine := Machine{
|
machine := Machine{
|
||||||
|
@ -397,18 +397,18 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
RegisterMethod: RegisterMethodAuthKey,
|
RegisterMethod: RegisterMethodAuthKey,
|
||||||
AuthKeyID: uint(pak.ID),
|
AuthKeyID: uint(pak.ID),
|
||||||
}
|
}
|
||||||
app.db.Save(&machine)
|
app.db.db.Save(&machine)
|
||||||
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
err = app.SetMachineUser(&machine, newUser.Name)
|
err = app.db.SetMachineUser(&machine, newUser.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
||||||
err = app.SetMachineUser(&machine, "non-existing-user")
|
err = app.db.SetMachineUser(&machine, "non-existing-user")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
err = app.SetMachineUser(&machine, newUser.Name)
|
err = app.db.SetMachineUser(&machine, newUser.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
42
hscontrol/util/addr.go
Normal file
42
hscontrol/util/addr.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"go4.org/netipx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
||||||
|
var network, broadcast netip.Addr
|
||||||
|
ipRange := netipx.RangeOfPrefix(na)
|
||||||
|
network = ipRange.From()
|
||||||
|
broadcast = ipRange.To()
|
||||||
|
|
||||||
|
return network, broadcast
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
|
||||||
|
result := make([]netip.Prefix, len(prefixes))
|
||||||
|
|
||||||
|
for index, prefixStr := range prefixes {
|
||||||
|
prefix, err := netip.ParsePrefix(prefixStr)
|
||||||
|
if err != nil {
|
||||||
|
return []netip.Prefix{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result[index] = prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringOrPrefixListContains[T string | netip.Prefix](ts []T, t T) bool {
|
||||||
|
for _, v := range ts {
|
||||||
|
if reflect.DeepEqual(v, t) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
43
hscontrol/util/file.go
Normal file
43
hscontrol/util/file.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Base8 = 8
|
||||||
|
Base10 = 10
|
||||||
|
BitSize16 = 16
|
||||||
|
BitSize32 = 32
|
||||||
|
BitSize64 = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
func AbsolutePathFromConfigPath(path string) string {
|
||||||
|
// If a relative path is provided, prefix it with the directory where
|
||||||
|
// the config file was found.
|
||||||
|
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
|
||||||
|
dir, _ := filepath.Split(viper.ConfigFileUsed())
|
||||||
|
if dir != "" {
|
||||||
|
path = filepath.Join(dir, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFileMode(key string) fs.FileMode {
|
||||||
|
modeStr := viper.GetString(key)
|
||||||
|
|
||||||
|
mode, err := strconv.ParseUint(modeStr, Base8, BitSize64)
|
||||||
|
if err != nil {
|
||||||
|
return PermissionFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
return fs.FileMode(mode)
|
||||||
|
}
|
117
hscontrol/util/key.go
Normal file
117
hscontrol/util/key.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
|
||||||
|
// These constants are copied from the upstream tailscale.com/types/key
|
||||||
|
// library, because they are not exported.
|
||||||
|
// https://github.com/tailscale/tailscale/tree/main/types/key
|
||||||
|
|
||||||
|
// nodePublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded node public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
nodePublicHexPrefix = "nodekey:"
|
||||||
|
|
||||||
|
// machinePublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded machine public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
machinePublicHexPrefix = "mkey:"
|
||||||
|
|
||||||
|
// discoPublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded disco public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
discoPublicHexPrefix = "discokey:"
|
||||||
|
|
||||||
|
// privateKey prefix.
|
||||||
|
privateHexPrefix = "privkey:"
|
||||||
|
|
||||||
|
PermissionFallback = 0o700
|
||||||
|
|
||||||
|
ZstdCompression = "zstd"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
|
||||||
|
ErrCannotDecryptResponse = errors.New("cannot decrypt response")
|
||||||
|
)
|
||||||
|
|
||||||
|
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
|
||||||
|
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
|
||||||
|
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
|
||||||
|
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MachinePublicKeyEnsurePrefix(machineKey string) string {
|
||||||
|
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
|
||||||
|
return machinePublicHexPrefix + machineKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return machineKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodePublicKeyEnsurePrefix(nodeKey string) string {
|
||||||
|
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
|
||||||
|
return nodePublicHexPrefix + nodeKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodeKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
|
||||||
|
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
|
||||||
|
return discoPublicHexPrefix + discoKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return discoKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrivateKeyEnsurePrefix(privateKey string) string {
|
||||||
|
if !strings.HasPrefix(privateKey, privateHexPrefix) {
|
||||||
|
return privateHexPrefix + privateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeAndUnmarshalNaCl(
|
||||||
|
msg []byte,
|
||||||
|
output interface{},
|
||||||
|
pubKey *key.MachinePublic,
|
||||||
|
privKey *key.MachinePrivate,
|
||||||
|
) error {
|
||||||
|
// log.Trace().
|
||||||
|
// Str("pubkey", pubKey.ShortString()).
|
||||||
|
// Int("length", len(msg)).
|
||||||
|
// Msg("Trying to decrypt")
|
||||||
|
|
||||||
|
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
||||||
|
if !ok {
|
||||||
|
return ErrCannotDecryptResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
12
hscontrol/util/net.go
Normal file
12
hscontrol/util/net.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
|
||||||
|
return d.DialContext(ctx, "unix", addr)
|
||||||
|
}
|
85
hscontrol/util/string.go
Normal file
85
hscontrol/util/string.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateRandomBytes returns securely generated random bytes.
|
||||||
|
// It will return an error if the system's secure random
|
||||||
|
// number generator fails to function correctly, in which
|
||||||
|
// case the caller should not continue.
|
||||||
|
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||||
|
bytes := make([]byte, n)
|
||||||
|
|
||||||
|
// Note that err == nil only if we read len(b) bytes.
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded
|
||||||
|
// securely generated random string.
|
||||||
|
// It will return an error if the system's secure random
|
||||||
|
// number generator fails to function correctly, in which
|
||||||
|
// case the caller should not continue.
|
||||||
|
func GenerateRandomStringURLSafe(n int) (string, error) {
|
||||||
|
b, err := GenerateRandomBytes(n)
|
||||||
|
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRandomStringDNSSafe returns a DNS-safe
|
||||||
|
// securely generated random string.
|
||||||
|
// It will return an error if the system's secure random
|
||||||
|
// number generator fails to function correctly, in which
|
||||||
|
// case the caller should not continue.
|
||||||
|
func GenerateRandomStringDNSSafe(size int) (string, error) {
|
||||||
|
var str string
|
||||||
|
var err error
|
||||||
|
for len(str) < size {
|
||||||
|
str, err = GenerateRandomStringURLSafe(size)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
str = strings.ToLower(
|
||||||
|
strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return str[:size], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsStringInSlice(slice []string, str string) bool {
|
||||||
|
for _, s := range slice {
|
||||||
|
if s == str {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TailNodesToString(nodes []*tailcfg.Node) string {
|
||||||
|
temp := make([]string, len(nodes))
|
||||||
|
|
||||||
|
for index, node := range nodes {
|
||||||
|
temp[index] = node.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TailMapResponseToString(resp tailcfg.MapResponse) string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"{ Node: %s, Peers: %s }",
|
||||||
|
resp.Node.Name,
|
||||||
|
TailNodesToString(resp.Peers),
|
||||||
|
)
|
||||||
|
}
|
15
hscontrol/util/string_test.go
Normal file
15
hscontrol/util/string_test.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateRandomStringDNSSafe(t *testing.T) {
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
str, err := GenerateRandomStringDNSSafe(8)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Len(t, str, 8)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,361 +0,0 @@
|
||||||
// Codehere is mostly taken from github.com/tailscale/tailscale
|
|
||||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package hscontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go4.org/netipx"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
"tailscale.com/types/key"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ErrCannotDecryptResponse = Error("cannot decrypt response")
|
|
||||||
ErrCouldNotAllocateIP = Error("could not find any suitable IP")
|
|
||||||
|
|
||||||
// These constants are copied from the upstream tailscale.com/types/key
|
|
||||||
// library, because they are not exported.
|
|
||||||
// https://github.com/tailscale/tailscale/tree/main/types/key
|
|
||||||
|
|
||||||
// nodePublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded node public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
nodePublicHexPrefix = "nodekey:"
|
|
||||||
|
|
||||||
// machinePublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded machine public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
machinePublicHexPrefix = "mkey:"
|
|
||||||
|
|
||||||
// discoPublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded disco public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
discoPublicHexPrefix = "discokey:"
|
|
||||||
|
|
||||||
// privateKey prefix.
|
|
||||||
privateHexPrefix = "privkey:"
|
|
||||||
|
|
||||||
PermissionFallback = 0o700
|
|
||||||
|
|
||||||
ZstdCompression = "zstd"
|
|
||||||
)
|
|
||||||
|
|
||||||
var NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
|
|
||||||
|
|
||||||
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
|
|
||||||
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
|
|
||||||
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
|
|
||||||
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MachinePublicKeyEnsurePrefix(machineKey string) string {
|
|
||||||
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
|
|
||||||
return machinePublicHexPrefix + machineKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return machineKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodePublicKeyEnsurePrefix(nodeKey string) string {
|
|
||||||
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
|
|
||||||
return nodePublicHexPrefix + nodeKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodeKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
|
|
||||||
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
|
|
||||||
return discoPublicHexPrefix + discoKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return discoKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrivateKeyEnsurePrefix(privateKey string) string {
|
|
||||||
if !strings.HasPrefix(privateKey, privateHexPrefix) {
|
|
||||||
return privateHexPrefix + privateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return privateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
|
||||||
type Error string
|
|
||||||
|
|
||||||
func (e Error) Error() string { return string(e) }
|
|
||||||
|
|
||||||
func decode(
|
|
||||||
msg []byte,
|
|
||||||
output interface{},
|
|
||||||
pubKey *key.MachinePublic,
|
|
||||||
privKey *key.MachinePrivate,
|
|
||||||
) error {
|
|
||||||
log.Trace().
|
|
||||||
Str("pubkey", pubKey.ShortString()).
|
|
||||||
Int("length", len(msg)).
|
|
||||||
Msg("Trying to decrypt")
|
|
||||||
|
|
||||||
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
|
||||||
if !ok {
|
|
||||||
return ErrCannotDecryptResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(decrypted, output); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
|
|
||||||
var ips MachineAddresses
|
|
||||||
var err error
|
|
||||||
ipPrefixes := h.cfg.IPPrefixes
|
|
||||||
for _, ipPrefix := range ipPrefixes {
|
|
||||||
var ip *netip.Addr
|
|
||||||
ip, err = h.getAvailableIP(ipPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return ips, err
|
|
||||||
}
|
|
||||||
ips = append(ips, *ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ips, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
|
||||||
var network, broadcast netip.Addr
|
|
||||||
ipRange := netipx.RangeOfPrefix(na)
|
|
||||||
network = ipRange.From()
|
|
||||||
broadcast = ipRange.To()
|
|
||||||
|
|
||||||
return network, broadcast
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
|
|
||||||
usedIps, err := h.getUsedIPs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := GetIPPrefixEndpoints(ipPrefix)
|
|
||||||
|
|
||||||
// Get the first IP in our prefix
|
|
||||||
ip := ipPrefixNetworkAddress.Next()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if !ipPrefix.Contains(ip) {
|
|
||||||
return nil, ErrCouldNotAllocateIP
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
|
||||||
fallthrough
|
|
||||||
case usedIps.Contains(ip):
|
|
||||||
fallthrough
|
|
||||||
case ip == netip.Addr{} || ip.IsLoopback():
|
|
||||||
ip = ip.Next()
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
default:
|
|
||||||
return &ip, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
|
|
||||||
// FIXME: This really deserves a better data model,
|
|
||||||
// but this was quick to get running and it should be enough
|
|
||||||
// to begin experimenting with a dual stack tailnet.
|
|
||||||
var addressesSlices []string
|
|
||||||
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
|
||||||
|
|
||||||
var ips netipx.IPSetBuilder
|
|
||||||
for _, slice := range addressesSlices {
|
|
||||||
var machineAddresses MachineAddresses
|
|
||||||
err := machineAddresses.Scan(slice)
|
|
||||||
if err != nil {
|
|
||||||
return &netipx.IPSet{}, fmt.Errorf(
|
|
||||||
"failed to read ip from database: %w",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range machineAddresses {
|
|
||||||
ips.Add(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipSet, err := ips.IPSet()
|
|
||||||
if err != nil {
|
|
||||||
return &netipx.IPSet{}, fmt.Errorf(
|
|
||||||
"failed to build IP Set: %w",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipSet, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func tailNodesToString(nodes []*tailcfg.Node) string {
|
|
||||||
temp := make([]string, len(nodes))
|
|
||||||
|
|
||||||
for index, node := range nodes {
|
|
||||||
temp[index] = node.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
|
||||||
}
|
|
||||||
|
|
||||||
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"{ Node: %s, Peers: %s }",
|
|
||||||
resp.Node.Name,
|
|
||||||
tailNodesToString(resp.Peers),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
var d net.Dialer
|
|
||||||
|
|
||||||
return d.DialContext(ctx, "unix", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
|
|
||||||
result := make([]netip.Prefix, len(prefixes))
|
|
||||||
|
|
||||||
for index, prefixStr := range prefixes {
|
|
||||||
prefix, err := netip.ParsePrefix(prefixStr)
|
|
||||||
if err != nil {
|
|
||||||
return []netip.Prefix{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result[index] = prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains[T string | netip.Prefix](ts []T, t T) bool {
|
|
||||||
for _, v := range ts {
|
|
||||||
if reflect.DeepEqual(v, t) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateRandomBytes returns securely generated random bytes.
|
|
||||||
// It will return an error if the system's secure random
|
|
||||||
// number generator fails to function correctly, in which
|
|
||||||
// case the caller should not continue.
|
|
||||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
|
||||||
bytes := make([]byte, n)
|
|
||||||
|
|
||||||
// Note that err == nil only if we read len(b) bytes.
|
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded
|
|
||||||
// securely generated random string.
|
|
||||||
// It will return an error if the system's secure random
|
|
||||||
// number generator fails to function correctly, in which
|
|
||||||
// case the caller should not continue.
|
|
||||||
func GenerateRandomStringURLSafe(n int) (string, error) {
|
|
||||||
b, err := GenerateRandomBytes(n)
|
|
||||||
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateRandomStringDNSSafe returns a DNS-safe
|
|
||||||
// securely generated random string.
|
|
||||||
// It will return an error if the system's secure random
|
|
||||||
// number generator fails to function correctly, in which
|
|
||||||
// case the caller should not continue.
|
|
||||||
func GenerateRandomStringDNSSafe(size int) (string, error) {
|
|
||||||
var str string
|
|
||||||
var err error
|
|
||||||
for len(str) < size {
|
|
||||||
str, err = GenerateRandomStringURLSafe(size)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
str = strings.ToLower(
|
|
||||||
strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return str[:size], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsStringInSlice(slice []string, str string) bool {
|
|
||||||
for _, s := range slice {
|
|
||||||
if s == str {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func AbsolutePathFromConfigPath(path string) string {
|
|
||||||
// If a relative path is provided, prefix it with the directory where
|
|
||||||
// the config file was found.
|
|
||||||
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
|
|
||||||
dir, _ := filepath.Split(viper.ConfigFileUsed())
|
|
||||||
if dir != "" {
|
|
||||||
path = filepath.Join(dir, path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetFileMode(key string) fs.FileMode {
|
|
||||||
modeStr := viper.GetString(key)
|
|
||||||
|
|
||||||
mode, err := strconv.ParseUint(modeStr, Base8, BitSize64)
|
|
||||||
if err != nil {
|
|
||||||
return PermissionFallback
|
|
||||||
}
|
|
||||||
|
|
||||||
return fs.FileMode(mode)
|
|
||||||
}
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
|
@ -220,7 +221,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDC
|
||||||
}
|
}
|
||||||
portNotation := fmt.Sprintf("%d/tcp", port)
|
portNotation := fmt.Sprintf("%d/tcp", port)
|
||||||
|
|
||||||
hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
|
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
|
||||||
|
|
||||||
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
|
@ -110,7 +110,7 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength)
|
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/integrationutil"
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
|
@ -132,7 +133,7 @@ func WithHostPortBindings(bindings map[string][]string) Option {
|
||||||
// in the Docker container name.
|
// in the Docker container name.
|
||||||
func WithTestName(testName string) Option {
|
func WithTestName(testName string) Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength)
|
hash, _ := util.GenerateRandomStringDNSSafe(hsicHashLength)
|
||||||
|
|
||||||
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
||||||
hsic.hostname = hostname
|
hsic.hostname = hostname
|
||||||
|
@ -167,7 +168,7 @@ func New(
|
||||||
network *dockertest.Network,
|
network *dockertest.Network,
|
||||||
opts ...Option,
|
opts ...Option,
|
||||||
) (*HeadscaleInContainer, error) {
|
) (*HeadscaleInContainer, error) {
|
||||||
hash, err := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength)
|
hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
|
@ -105,7 +105,7 @@ type Scenario struct {
|
||||||
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
|
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
|
||||||
// a set of Users and TailscaleClients.
|
// a set of Users and TailscaleClients.
|
||||||
func NewScenario() (*Scenario, error) {
|
func NewScenario() (*Scenario, error) {
|
||||||
hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength)
|
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/integrationutil"
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
|
@ -150,7 +150,7 @@ func New(
|
||||||
network *dockertest.Network,
|
network *dockertest.Network,
|
||||||
opts ...Option,
|
opts ...Option,
|
||||||
) (*TailscaleInContainer, error) {
|
) (*TailscaleInContainer, error) {
|
||||||
hash, err := hscontrol.GenerateRandomStringDNSSafe(tsicHashLength)
|
hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue