diff --git a/api.go b/api.go index 84b79a5..ff73871 100644 --- a/api.go +++ b/api.go @@ -13,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" - "go4.org/mem" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -74,7 +73,9 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { func (h *Headscale) RegistrationHandler(ctx *gin.Context) { body, _ := io.ReadAll(ctx.Request.Body) machineKeyStr := ctx.Param("id") - machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(machineKeyStr)) if err != nil { log.Error(). Caller(). diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 2c7fa9b..5adc7f5 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -11,7 +11,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/pterm/pterm" "github.com/spf13/cobra" - "go4.org/mem" "google.golang.org/grpc/status" "tailscale.com/types/key" ) @@ -486,7 +485,8 @@ func nodesToPtables( expiry = machine.Expiry.AsTime() } - nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) if err != nil { return nil, err } diff --git a/integration_cli_test.go b/integration_cli_test.go index ee94054..eb55322 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -720,6 +720,7 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { []string{}, ) assert.Nil(s.T(), err) + fmt.Println("Error: ", err) var listOnlySharedMachineNamespace []v1.Machine err = json.Unmarshal( @@ -728,6 +729,8 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { ) assert.Nil(s.T(), err) + fmt.Println("List: ", listOnlySharedMachineNamespaceResult) + fmt.Println("List2: ", listOnlySharedMachineNamespace) assert.Len(s.T(), listOnlySharedMachineNamespace, 2) assert.Equal(s.T(), uint64(6), listOnlySharedMachineNamespace[0].Id) diff --git a/machine.go b/machine.go index 26c59c6..03caa5a 100644 --- a/machine.go +++ b/machine.go @@ -12,7 +12,6 @@ import ( "github.com/fatih/set" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" - "go4.org/mem" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/datatypes" "inet.af/netaddr" @@ -439,7 +438,8 @@ func (machine Machine) toNode( dnsConfig *tailcfg.DNSConfig, includeRoutes bool, ) (*tailcfg.Node, error) { - nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) if err != nil { log.Trace(). Caller(). @@ -449,19 +449,18 @@ func (machine Machine) toNode( return nil, fmt.Errorf("failed to parse node public key: %w", err) } - machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey)) + var machineKey key.MachinePublic + err = machineKey.UnmarshalText([]byte(machine.MachineKey)) if err != nil { return nil, fmt.Errorf("failed to parse machine public key: %w", err) } var discoKey key.DiscoPublic if machine.DiscoKey != "" { - dKey := key.DiscoPublic{} - err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey)) + err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey)) if err != nil { return nil, fmt.Errorf("failed to parse disco public key: %w", err) } - discoKey = key.DiscoPublic(dKey) } else { discoKey = key.DiscoPublic{} } @@ -634,7 +633,8 @@ func (h *Headscale) RegisterMachine( return nil, err } - machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) + var machineKey key.MachinePublic + err = machineKey.UnmarshalText([]byte(machineKeyStr)) if err != nil { return nil, err } diff --git a/oidc.go b/oidc.go index 02666d8..48ad718 100644 --- a/oidc.go +++ b/oidc.go @@ -15,7 +15,6 @@ import ( "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" - "go4.org/mem" "golang.org/x/oauth2" "gorm.io/gorm" "tailscale.com/types/key" @@ -192,7 +191,8 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machineKeyStr, machineKeyOK := machineKeyIf.(string) - machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) + var machineKey key.MachinePublic + err = machineKey.UnmarshalText([]byte(machineKeyStr)) if err != nil { log.Error(). Msg("could not parse machine public key") diff --git a/poll.go b/poll.go index b993fa7..70bacc6 100644 --- a/poll.go +++ b/poll.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" - "go4.org/mem" "gorm.io/datatypes" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -36,8 +35,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("id", ctx.Param("id")). Msg("PollNetMapHandler called") body, _ := io.ReadAll(ctx.Request.Body) - mKeyStr := ctx.Param("id") - mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr)) + machineKeyStr := ctx.Param("id") + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(machineKeyStr)) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -48,7 +49,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { return } req := tailcfg.MapRequest{} - err = decode(body, &req, &mKey, h.privateKey) + err = decode(body, &req, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -59,19 +60,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { return } - machine, err := h.GetMachineByMachineKey(mKey) + machine, err := h.GetMachineByMachineKey(machineKey) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", mKey.String()) + Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) ctx.String(http.StatusUnauthorized, "") return } log.Error(). Str("handler", "PollNetMap"). - Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String()) + Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) ctx.String(http.StatusInternalServerError, "") } log.Trace(). @@ -101,7 +102,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { } h.db.Save(&machine) - data, err := h.getMapResponse(mKey, req, machine) + data, err := h.getMapResponse(machineKey, req, machine) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -206,7 +207,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { ctx, machine, req, - mKey, + machineKey, pollDataChan, keepAliveChan, updateChan, diff --git a/utils.go b/utils.go index d0fa1b5..fa9f028 100644 --- a/utils.go +++ b/utils.go @@ -20,22 +20,12 @@ import ( const ( errCannotDecryptReponse = Error("cannot decrypt response") - errResponseMissingNonce = Error("response missing nonce") errCouldNotAllocateIP = Error("could not find any suitable IP") // These constants are copied from the upstream tailscale.com/types/key // library, because they are not exported. // https://github.com/tailscale/tailscale/tree/main/types/key - // nodePrivateHexPrefix is the prefix used to identify a - // hex-encoded node private key. - // - // This prefix name is a little unfortunate, in that it comes from - // WireGuard's own key types, and we've used it for both key types - // we persist to disk (machine and node keys). But we're stuck - // with it for now, barring another round of tricky migration. - nodePrivateHexPrefix = "privkey:" - // nodePublicHexPrefix is the prefix used to identify a // hex-encoded node public key. // @@ -43,14 +33,6 @@ const ( // changed. nodePublicHexPrefix = "nodekey:" - // machinePrivateHexPrefix is the prefix used to identify a - // hex-encoded machine private key. - // - // This prefix name is a little unfortunate, in that it comes from - // WireGuard's own key types. Unfortunately we're stuck with it for - // machine keys, because we serialize them to disk with this prefix. - machinePrivateHexPrefix = "privkey:" - // machinePublicHexPrefix is the prefix used to identify a // hex-encoded machine public key. //