Unmarshal keys in the non-deprecated way
This commit is contained in:
parent
0012c76170
commit
c38f00fab8
7 changed files with 27 additions and 40 deletions
5
api.go
5
api.go
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/mem"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -74,7 +73,9 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||||
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
body, _ := io.ReadAll(ctx.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
machineKeyStr := ctx.Param("id")
|
machineKeyStr := ctx.Param("id")
|
||||||
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
|
||||||
|
var machineKey key.MachinePublic
|
||||||
|
err := machineKey.UnmarshalText([]byte(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"go4.org/mem"
|
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
@ -486,7 +485,8 @@ func nodesToPtables(
|
||||||
expiry = machine.Expiry.AsTime()
|
expiry = machine.Expiry.AsTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
|
var nodeKey key.NodePublic
|
||||||
|
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -720,6 +720,7 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
[]string{},
|
[]string{},
|
||||||
)
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
fmt.Println("Error: ", err)
|
||||||
|
|
||||||
var listOnlySharedMachineNamespace []v1.Machine
|
var listOnlySharedMachineNamespace []v1.Machine
|
||||||
err = json.Unmarshal(
|
err = json.Unmarshal(
|
||||||
|
@ -728,6 +729,8 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
||||||
)
|
)
|
||||||
assert.Nil(s.T(), err)
|
assert.Nil(s.T(), err)
|
||||||
|
|
||||||
|
fmt.Println("List: ", listOnlySharedMachineNamespaceResult)
|
||||||
|
fmt.Println("List2: ", listOnlySharedMachineNamespace)
|
||||||
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
|
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
|
||||||
|
|
||||||
assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id)
|
assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id)
|
||||||
|
|
14
machine.go
14
machine.go
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/fatih/set"
|
"github.com/fatih/set"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/mem"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
|
@ -439,7 +438,8 @@ func (machine Machine) toNode(
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
includeRoutes bool,
|
includeRoutes bool,
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
|
var nodeKey key.NodePublic
|
||||||
|
err := nodeKey.UnmarshalText([]byte(machine.NodeKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -449,19 +449,18 @@ func (machine Machine) toNode(
|
||||||
return nil, fmt.Errorf("failed to parse node public key: %w", err)
|
return nil, fmt.Errorf("failed to parse node public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey))
|
var machineKey key.MachinePublic
|
||||||
|
err = machineKey.UnmarshalText([]byte(machine.MachineKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var discoKey key.DiscoPublic
|
var discoKey key.DiscoPublic
|
||||||
if machine.DiscoKey != "" {
|
if machine.DiscoKey != "" {
|
||||||
dKey := key.DiscoPublic{}
|
err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
|
||||||
err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||||
}
|
}
|
||||||
discoKey = key.DiscoPublic(dKey)
|
|
||||||
} else {
|
} else {
|
||||||
discoKey = key.DiscoPublic{}
|
discoKey = key.DiscoPublic{}
|
||||||
}
|
}
|
||||||
|
@ -634,7 +633,8 @@ func (h *Headscale) RegisterMachine(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
var machineKey key.MachinePublic
|
||||||
|
err = machineKey.UnmarshalText([]byte(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
4
oidc.go
4
oidc.go
|
@ -15,7 +15,6 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/mem"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -192,7 +191,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
|
|
||||||
machineKeyStr, machineKeyOK := machineKeyIf.(string)
|
machineKeyStr, machineKeyOK := machineKeyIf.(string)
|
||||||
|
|
||||||
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
var machineKey key.MachinePublic
|
||||||
|
err = machineKey.UnmarshalText([]byte(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msg("could not parse machine public key")
|
Msg("could not parse machine public key")
|
||||||
|
|
19
poll.go
19
poll.go
|
@ -9,7 +9,6 @@ import (
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/mem"
|
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -36,8 +35,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
Str("id", ctx.Param("id")).
|
Str("id", ctx.Param("id")).
|
||||||
Msg("PollNetMapHandler called")
|
Msg("PollNetMapHandler called")
|
||||||
body, _ := io.ReadAll(ctx.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
mKeyStr := ctx.Param("id")
|
machineKeyStr := ctx.Param("id")
|
||||||
mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr))
|
|
||||||
|
var machineKey key.MachinePublic
|
||||||
|
err := machineKey.UnmarshalText([]byte(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -48,7 +49,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req := tailcfg.MapRequest{}
|
req := tailcfg.MapRequest{}
|
||||||
err = decode(body, &req, &mKey, h.privateKey)
|
err = decode(body, &req, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -59,19 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err := h.GetMachineByMachineKey(mKey)
|
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.String())
|
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
|
||||||
ctx.String(http.StatusUnauthorized, "")
|
ctx.String(http.StatusUnauthorized, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String())
|
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
|
||||||
ctx.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
|
@ -101,7 +102,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
data, err := h.getMapResponse(mKey, req, machine)
|
data, err := h.getMapResponse(machineKey, req, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -206,7 +207,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
ctx,
|
ctx,
|
||||||
machine,
|
machine,
|
||||||
req,
|
req,
|
||||||
mKey,
|
machineKey,
|
||||||
pollDataChan,
|
pollDataChan,
|
||||||
keepAliveChan,
|
keepAliveChan,
|
||||||
updateChan,
|
updateChan,
|
||||||
|
|
18
utils.go
18
utils.go
|
@ -20,22 +20,12 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errCannotDecryptReponse = Error("cannot decrypt response")
|
errCannotDecryptReponse = Error("cannot decrypt response")
|
||||||
errResponseMissingNonce = Error("response missing nonce")
|
|
||||||
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
||||||
|
|
||||||
// These constants are copied from the upstream tailscale.com/types/key
|
// These constants are copied from the upstream tailscale.com/types/key
|
||||||
// library, because they are not exported.
|
// library, because they are not exported.
|
||||||
// https://github.com/tailscale/tailscale/tree/main/types/key
|
// https://github.com/tailscale/tailscale/tree/main/types/key
|
||||||
|
|
||||||
// nodePrivateHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded node private key.
|
|
||||||
//
|
|
||||||
// This prefix name is a little unfortunate, in that it comes from
|
|
||||||
// WireGuard's own key types, and we've used it for both key types
|
|
||||||
// we persist to disk (machine and node keys). But we're stuck
|
|
||||||
// with it for now, barring another round of tricky migration.
|
|
||||||
nodePrivateHexPrefix = "privkey:"
|
|
||||||
|
|
||||||
// nodePublicHexPrefix is the prefix used to identify a
|
// nodePublicHexPrefix is the prefix used to identify a
|
||||||
// hex-encoded node public key.
|
// hex-encoded node public key.
|
||||||
//
|
//
|
||||||
|
@ -43,14 +33,6 @@ const (
|
||||||
// changed.
|
// changed.
|
||||||
nodePublicHexPrefix = "nodekey:"
|
nodePublicHexPrefix = "nodekey:"
|
||||||
|
|
||||||
// machinePrivateHexPrefix is the prefix used to identify a
|
|
||||||
// hex-encoded machine private key.
|
|
||||||
//
|
|
||||||
// This prefix name is a little unfortunate, in that it comes from
|
|
||||||
// WireGuard's own key types. Unfortunately we're stuck with it for
|
|
||||||
// machine keys, because we serialize them to disk with this prefix.
|
|
||||||
machinePrivateHexPrefix = "privkey:"
|
|
||||||
|
|
||||||
// machinePublicHexPrefix is the prefix used to identify a
|
// machinePublicHexPrefix is the prefix used to identify a
|
||||||
// hex-encoded machine public key.
|
// hex-encoded machine public key.
|
||||||
//
|
//
|
||||||
|
|
Loading…
Reference in a new issue