remove the use key stripping and store the proper keys (#1603)
This commit is contained in:
parent
2af71c9e31
commit
c0fd06e3f5
21 changed files with 99 additions and 198 deletions
|
@ -529,7 +529,7 @@ func nodesToPtables(
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText(
|
err := machineKey.UnmarshalText(
|
||||||
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)),
|
[]byte(node.MachineKey),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
machineKey = key.MachinePublic{}
|
machineKey = key.MachinePublic{}
|
||||||
|
@ -537,7 +537,7 @@ func nodesToPtables(
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err = nodeKey.UnmarshalText(
|
err = nodeKey.UnmarshalText(
|
||||||
[]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey)),
|
[]byte(node.NodeKey),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -911,10 +911,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||||
privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey)
|
|
||||||
|
|
||||||
var machineKey key.MachinePrivate
|
var machineKey key.MachinePrivate
|
||||||
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
|
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Msg("This might be due to a legacy (headscale pre-0.12) private key. " +
|
Msg("This might be due to a legacy (headscale pre-0.12) private key. " +
|
||||||
|
|
|
@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
|
||||||
// is that the client will hammer headscale with requests until it gets a
|
// is that the client will hammer headscale with requests until it gets a
|
||||||
// successful RegisterResponse.
|
// successful RegisterResponse.
|
||||||
if registerRequest.Followup != "" {
|
if registerRequest.Followup != "" {
|
||||||
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
|
if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
Str("node", registerRequest.Hostinfo.Hostname).
|
||||||
|
@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
|
||||||
// We create the node and then keep it around until a callback
|
// We create the node and then keep it around until a callback
|
||||||
// happens
|
// happens
|
||||||
newNode := types.Node{
|
newNode := types.Node{
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: machineKey.String(),
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
NodeKey: registerRequest.NodeKey.String(),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
}
|
}
|
||||||
|
@ -136,7 +136,7 @@ func (h *Headscale) handleRegister(
|
||||||
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
|
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
|
||||||
var storedMachineKey key.MachinePublic
|
var storedMachineKey key.MachinePublic
|
||||||
err = storedMachineKey.UnmarshalText(
|
err = storedMachineKey.UnmarshalText(
|
||||||
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)),
|
[]byte(node.MachineKey),
|
||||||
)
|
)
|
||||||
if err != nil || storedMachineKey.IsZero() {
|
if err != nil || storedMachineKey.IsZero() {
|
||||||
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
|
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
|
||||||
|
@ -156,7 +156,7 @@ func (h *Headscale) handleRegister(
|
||||||
// - Trying to log out (sending a expiry in the past)
|
// - Trying to log out (sending a expiry in the past)
|
||||||
// - A valid, registered node, looking for /map
|
// - A valid, registered node, looking for /map
|
||||||
// - Expired node wanting to reauthenticate
|
// - Expired node wanting to reauthenticate
|
||||||
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) {
|
if node.NodeKey == registerRequest.NodeKey.String() {
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !registerRequest.Expiry.IsZero() &&
|
if !registerRequest.Expiry.IsZero() &&
|
||||||
|
@ -176,7 +176,7 @@ func (h *Headscale) handleRegister(
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
if node.NodeKey == registerRequest.OldNodeKey.String() &&
|
||||||
!node.IsExpired() {
|
!node.IsExpired() {
|
||||||
h.handleNodeKeyRefresh(
|
h.handleNodeKeyRefresh(
|
||||||
writer,
|
writer,
|
||||||
|
@ -207,9 +207,9 @@ func (h *Headscale) handleRegister(
|
||||||
// we need to make sure the NodeKey matches the one in the request
|
// we need to make sure the NodeKey matches the one in the request
|
||||||
// TODO(juan): What happens when using fast user switching between two
|
// TODO(juan): What happens when using fast user switching between two
|
||||||
// headscale-managed tailnets?
|
// headscale-managed tailnets?
|
||||||
node.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
node.NodeKey = registerRequest.NodeKey.String()
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
registerRequest.NodeKey.String(),
|
||||||
*node,
|
*node,
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
@ -294,7 +294,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
Str("node", registerRequest.Hostinfo.Hostname).
|
||||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
||||||
|
|
||||||
nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
nodeKey := registerRequest.NodeKey.String()
|
||||||
|
|
||||||
// retrieve node information if it exist
|
// retrieve node information if it exist
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
|
@ -342,7 +342,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
} else {
|
} else {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
|
givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -359,7 +359,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
Hostname: registerRequest.Hostinfo.Hostname,
|
||||||
GivenName: givenName,
|
GivenName: givenName,
|
||||||
UserID: pak.User.ID,
|
UserID: pak.User.ID,
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: machineKey.String(),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
Expiry: ®isterRequest.Expiry,
|
Expiry: ®isterRequest.Expiry,
|
||||||
NodeKey: nodeKey,
|
NodeKey: nodeKey,
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (h *Headscale) RegistrationHandler(
|
||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -35,9 +35,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -83,9 +80,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: uint64(index),
|
ID: uint64(index),
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -173,9 +167,6 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -252,6 +253,27 @@ func NewHeadscaleDatabase(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure all keys have correct prefixes
|
||||||
|
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
|
||||||
|
nodes := types.Nodes{}
|
||||||
|
if err := dbConn.Find(&nodes).Error; err != nil {
|
||||||
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
|
||||||
|
node.DiscoKey = "discokey:" + node.DiscoKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(node.NodeKey, "nodekey:") {
|
||||||
|
node.NodeKey = "nodekey:" + node.NodeKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(node.MachineKey, "mkey:") {
|
||||||
|
node.MachineKey = "mkey:" + node.MachineKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): is this needed?
|
// TODO(kradalby): is this needed?
|
||||||
err = db.setValue("db_version", dbVersion)
|
err = db.setValue("db_version", dbVersion)
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,7 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Preload("Routes").
|
Preload("Routes").
|
||||||
First(&mach, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ func (hsdb *HSDatabase) GetNodeByNodeKey(
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Preload("Routes").
|
Preload("Routes").
|
||||||
First(&node, "node_key = ?",
|
First(&node, "node_key = ?",
|
||||||
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
nodeKey.String()); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,9 +224,9 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Preload("Routes").
|
Preload("Routes").
|
||||||
First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
|
First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
|
||||||
util.MachinePublicKeyStripPrefix(machineKey),
|
machineKey.String(),
|
||||||
util.NodePublicKeyStripPrefix(nodeKey),
|
nodeKey.String(),
|
||||||
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
oldNodeKey.String()); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
||||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||||
Msg("Registering node from API/CLI or auth callback")
|
Msg("Registering node from API/CLI or auth callback")
|
||||||
|
|
||||||
if nodeInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
|
if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
|
||||||
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
||||||
user, err := hsdb.getUser(userName)
|
user, err := hsdb.getUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -507,7 +507,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
|
||||||
defer hsdb.mu.Unlock()
|
defer hsdb.mu.Unlock()
|
||||||
|
|
||||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey),
|
NodeKey: nodeKey.String(),
|
||||||
}).Error; err != nil {
|
}).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -524,7 +524,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
|
||||||
defer hsdb.mu.Unlock()
|
defer hsdb.mu.Unlock()
|
||||||
|
|
||||||
if err := hsdb.db.Model(node).Updates(types.Node{
|
if err := hsdb.db.Model(node).Updates(types.Node{
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
MachineKey: machineKey.String(),
|
||||||
}).Error; err != nil {
|
}).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,8 +82,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: machineKey.Public().String(),
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: nodeKey.Public().String(),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
|
@ -113,8 +113,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
MachineKey: machineKey.Public().String(),
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: nodeKey.Public().String(),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
|
@ -575,7 +575,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
NodeKey: nodeKey.Public().String(),
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Hostname: "test",
|
Hostname: "test",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
|
|
|
@ -77,9 +77,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -101,9 +98,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -138,9 +132,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
now := time.Now().Add(-time.Second * 30)
|
now := time.Now().Add(-time.Second * 30)
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
|
@ -29,9 +29,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_get_route_node",
|
Hostname: "test_get_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -80,9 +77,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -154,9 +148,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
}
|
}
|
||||||
node1 := types.Node{
|
node1 := types.Node{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -179,9 +170,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
}
|
}
|
||||||
node2 := types.Node{
|
node2 := types.Node{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -240,9 +228,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
node1 := types.Node{
|
node1 := types.Node{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -277,9 +262,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
}
|
}
|
||||||
node2 := types.Node{
|
node2 := types.Node{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -382,9 +364,6 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
node1 := types.Node{
|
node1 := types.Node{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "test_enable_route_node",
|
Hostname: "test_enable_route_node",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
|
@ -48,9 +48,6 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
@ -103,9 +100,6 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: "foo",
|
|
||||||
NodeKey: "bar",
|
|
||||||
DiscoKey: "faa",
|
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: oldUser.ID,
|
UserID: oldUser.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
|
@ -545,7 +545,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
api.h.registrationCache.Set(
|
api.h.registrationCache.Set(
|
||||||
util.NodePublicKeyStripPrefix(nodeKey),
|
nodeKey.String(),
|
||||||
newNode,
|
newNode,
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
@ -71,7 +72,7 @@ func (h *Headscale) KeyHandler(
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err := writer.Write(
|
_, err := writer.Write(
|
||||||
[]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public())),
|
[]byte(strings.TrimPrefix(h.privateKey2019.Public().String(), "mkey:")),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -229,7 +230,7 @@ func (h *Headscale) RegisterWebAPI(
|
||||||
// the template and log an error.
|
// the template and log an error.
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
[]byte(nodeKeyStr),
|
||||||
)
|
)
|
||||||
|
|
||||||
if !ok || nodeKeyStr == "" || err != nil {
|
if !ok || nodeKeyStr == "" || err != nil {
|
||||||
|
|
|
@ -369,7 +369,7 @@ func (m *Mapper) marshalMapResponse(
|
||||||
atomic.AddUint64(&m.seq, 1)
|
atomic.AddUint64(&m.seq, 1)
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)))
|
err := machineKey.UnmarshalText([]byte(node.MachineKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -122,7 +122,7 @@ func (h *Headscale) RegisterOIDC(
|
||||||
// the template and log an error.
|
// the template and log an error.
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
[]byte(nodeKeyStr),
|
||||||
)
|
)
|
||||||
|
|
||||||
if !ok || nodeKeyStr == "" || err != nil {
|
if !ok || nodeKeyStr == "" || err != nil {
|
||||||
|
@ -154,7 +154,7 @@ func (h *Headscale) RegisterOIDC(
|
||||||
// place the node key into the state cache, so it can be retrieved later
|
// place the node key into the state cache, so it can be retrieved later
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
stateStr,
|
stateStr,
|
||||||
util.NodePublicKeyStripPrefix(nodeKey),
|
nodeKey,
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -479,10 +479,11 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
||||||
}
|
}
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
|
nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic)
|
||||||
if !nodeKeyOK {
|
if !nodeKeyOK {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Msg("requested node state key is not a string")
|
Interface("got", nodeKeyIf).
|
||||||
|
Msg("requested node state key is not a nodekey")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
_, err := writer.Write([]byte("state is invalid"))
|
_, err := writer.Write([]byte("state is invalid"))
|
||||||
|
@ -493,24 +494,6 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
||||||
return nil, false, errOIDCInvalidNodeState
|
return nil, false, errOIDCInvalidNodeState
|
||||||
}
|
}
|
||||||
|
|
||||||
err := nodeKey.UnmarshalText(
|
|
||||||
[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("nodeKey", nodeKeyFromCache).
|
|
||||||
Bool("nodeKeyOK", nodeKeyOK).
|
|
||||||
Msg("could not parse node public key")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("could not parse node public key"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieve node information if it exist
|
// retrieve node information if it exist
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
// exist, then this is a new node and we will move
|
// exist, then this is a new node and we will move
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -91,7 +90,7 @@ func (h *Headscale) handlePoll(
|
||||||
node.LastSeen = &now
|
node.LastSeen = &now
|
||||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||||
node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
node.DiscoKey = mapRequest.DiscoKey.String()
|
||||||
node.Endpoints = mapRequest.Endpoints
|
node.Endpoints = mapRequest.Endpoints
|
||||||
|
|
||||||
if err := h.db.NodeSave(node); err != nil {
|
if err := h.db.NodeSave(node); err != nil {
|
||||||
|
@ -144,7 +143,7 @@ func (h *Headscale) handlePoll(
|
||||||
node.LastSeen = &now
|
node.LastSeen = &now
|
||||||
node.Hostname = mapRequest.Hostinfo.Hostname
|
node.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||||
node.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
node.DiscoKey = mapRequest.DiscoKey.String()
|
||||||
node.Endpoints = mapRequest.Endpoints
|
node.Endpoints = mapRequest.Endpoints
|
||||||
|
|
||||||
// When a node connects to control, list the peers it has at
|
// When a node connects to control, list the peers it has at
|
||||||
|
|
|
@ -45,7 +45,7 @@ func (h *Headscale) PollNetMapHandler(
|
||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr)))
|
err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -295,7 +294,7 @@ func (node *Node) MachinePublicKey() (key.MachinePublic, error) {
|
||||||
|
|
||||||
if node.MachineKey != "" {
|
if node.MachineKey != "" {
|
||||||
err := machineKey.UnmarshalText(
|
err := machineKey.UnmarshalText(
|
||||||
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)),
|
[]byte(node.MachineKey),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err)
|
return key.MachinePublic{}, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||||
|
@ -309,7 +308,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
|
||||||
var discoKey key.DiscoPublic
|
var discoKey key.DiscoPublic
|
||||||
if node.DiscoKey != "" {
|
if node.DiscoKey != "" {
|
||||||
err := discoKey.UnmarshalText(
|
err := discoKey.UnmarshalText(
|
||||||
[]byte(util.DiscoPublicKeyEnsurePrefix(node.DiscoKey)),
|
[]byte(node.DiscoKey),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err)
|
return key.DiscoPublic{}, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||||
|
@ -323,7 +322,7 @@ func (node *Node) DiscoPublicKey() (key.DiscoPublic, error) {
|
||||||
|
|
||||||
func (node *Node) NodePublicKey() (key.NodePublic, error) {
|
func (node *Node) NodePublicKey() (key.NodePublic, error) {
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey)))
|
err := nodeKey.UnmarshalText([]byte(node.NodeKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err)
|
return key.NodePublic{}, fmt.Errorf("failed to parse node public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,11 +11,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Base8 = 8
|
Base8 = 8
|
||||||
Base10 = 10
|
Base10 = 10
|
||||||
BitSize16 = 16
|
BitSize16 = 16
|
||||||
BitSize32 = 32
|
BitSize32 = 32
|
||||||
BitSize64 = 64
|
BitSize64 = 64
|
||||||
|
PermissionFallback = 0o700
|
||||||
)
|
)
|
||||||
|
|
||||||
func AbsolutePathFromConfigPath(path string) string {
|
func AbsolutePathFromConfigPath(path string) string {
|
||||||
|
|
|
@ -4,106 +4,22 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
|
|
||||||
// These constants are copied from the upstream tailscale.com/types/key
|
|
||||||
// library, because they are not exported.
|
|
||||||
// https://github.com/tailscale/tailscale/tree/main/types/key
|
|
||||||
|
|
||||||
// nodePublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded node public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
nodePublicHexPrefix = "nodekey:"
|
|
||||||
|
|
||||||
// machinePublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded machine public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
machinePublicHexPrefix = "mkey:"
|
|
||||||
|
|
||||||
// discoPublicHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded disco public key.
|
|
||||||
//
|
|
||||||
// This prefix is used in the control protocol, so cannot be
|
|
||||||
// changed.
|
|
||||||
discoPublicHexPrefix = "discokey:"
|
|
||||||
|
|
||||||
// privateKey prefix.
|
|
||||||
privateHexPrefix = "privkey:"
|
|
||||||
|
|
||||||
PermissionFallback = 0o700
|
|
||||||
|
|
||||||
ZstdCompression = "zstd"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
|
NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
|
||||||
ErrCannotDecryptResponse = errors.New("cannot decrypt response")
|
ErrCannotDecryptResponse = errors.New("cannot decrypt response")
|
||||||
|
ZstdCompression = "zstd"
|
||||||
)
|
)
|
||||||
|
|
||||||
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
|
|
||||||
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
|
|
||||||
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
|
|
||||||
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MachinePublicKeyEnsurePrefix(machineKey string) string {
|
|
||||||
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
|
|
||||||
return machinePublicHexPrefix + machineKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return machineKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodePublicKeyEnsurePrefix(nodeKey string) string {
|
|
||||||
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
|
|
||||||
return nodePublicHexPrefix + nodeKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodeKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
|
|
||||||
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
|
|
||||||
return discoPublicHexPrefix + discoKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return discoKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrivateKeyEnsurePrefix(privateKey string) string {
|
|
||||||
if !strings.HasPrefix(privateKey, privateHexPrefix) {
|
|
||||||
return privateHexPrefix + privateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return privateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeAndUnmarshalNaCl(
|
func DecodeAndUnmarshalNaCl(
|
||||||
msg []byte,
|
msg []byte,
|
||||||
output interface{},
|
output interface{},
|
||||||
pubKey *key.MachinePublic,
|
pubKey *key.MachinePublic,
|
||||||
privKey *key.MachinePrivate,
|
privKey *key.MachinePrivate,
|
||||||
) error {
|
) error {
|
||||||
// log.Trace().
|
|
||||||
// Str("pubkey", pubKey.ShortString()).
|
|
||||||
// Int("length", len(msg)).
|
|
||||||
// Msg("Trying to decrypt")
|
|
||||||
|
|
||||||
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrCannotDecryptResponse
|
return ErrCannotDecryptResponse
|
||||||
|
|
|
@ -348,6 +348,14 @@ func (t *HeadscaleInContainer) Shutdown() error {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = t.SaveDatabase("/tmp/control")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf(
|
||||||
|
"Failed to save database from control: %s",
|
||||||
|
fmt.Errorf("failed to save database from control: %w", err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return t.pool.Purge(t.container)
|
return t.pool.Purge(t.container)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -393,6 +401,24 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
||||||
|
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(
|
||||||
|
path.Join(savePath, t.hostname+".db.tar"),
|
||||||
|
tarFile,
|
||||||
|
os.ModePerm,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Execute runs a command inside the Headscale container and returns the
|
// Execute runs a command inside the Headscale container and returns the
|
||||||
// result of stdout as a string.
|
// result of stdout as a string.
|
||||||
func (t *HeadscaleInContainer) Execute(
|
func (t *HeadscaleInContainer) Execute(
|
||||||
|
|
Loading…
Reference in a new issue