diff --git a/api.go b/api.go index 1a51fdf..84b79a5 100644 --- a/api.go +++ b/api.go @@ -13,9 +13,10 @@ import ( "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" + "go4.org/mem" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) const ( @@ -34,7 +35,7 @@ func (h *Headscale) KeyHandler(ctx *gin.Context) { ctx.Data( http.StatusOK, "text/plain; charset=utf-8", - []byte(h.publicKey.HexString()), + []byte(MachinePublicKeyStripPrefix(*h.publicKey)), ) } @@ -73,10 +74,10 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { func (h *Headscale) RegistrationHandler(ctx *gin.Context) { body, _ := io.ReadAll(ctx.Request.Body) machineKeyStr := ctx.Param("id") - machineKey, err := wgkey.ParseHex(machineKeyStr) + machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot parse machine key") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() @@ -88,7 +89,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { err = decode(body, &req, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot decode message") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() @@ -98,17 +99,17 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { } now := time.Now().UTC() - machine, err := h.GetMachineByMachineKey(machineKey.HexString()) + machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") newMachine := Machine{ Expiry: &time.Time{}, - MachineKey: machineKey.HexString(), + MachineKey: MachinePublicKeyStripPrefix(machineKey), Name: req.Hostinfo.Hostname, } if err := h.db.Create(&newMachine).Error; err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Could not create row") machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). @@ -125,7 +126,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // - Trying to log out (sending a expiry in the past) // - A valid, registered machine, looking for the node map // - Expired machine wanting to reauthenticate - if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() { + if machine.NodeKey == req.NodeKey.String() { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { @@ -144,7 +145,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && + if machine.NodeKey == req.OldNodeKey.String() && !machine.isExpired() { h.handleMachineRefreshKey(ctx, machineKey, req, *machine) @@ -168,7 +169,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { } func (h *Headscale) getMapResponse( - machineKey wgkey.Key, + machineKey key.MachinePublic, req tailcfg.MapRequest, machine *Machine, ) ([]byte, error) { @@ -179,6 +180,7 @@ func (h *Headscale) getMapResponse( node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). + Caller(). Str("func", "getMapResponse"). Err(err). Msg("Cannot convert to node") @@ -189,6 +191,7 @@ func (h *Headscale) getMapResponse( peers, err := h.getValidPeers(machine) if err != nil { log.Error(). + Caller(). Str("func", "getMapResponse"). Err(err). Msg("Cannot fetch peers") @@ -201,6 +204,7 @@ func (h *Headscale) getMapResponse( nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). + Caller(). Str("func", "getMapResponse"). Err(err). Msg("Failed to convert peers to Tailscale nodes") @@ -238,10 +242,7 @@ func (h *Headscale) getMapResponse( encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) - if err != nil { - return nil, err - } + respBody = h.privateKey.SealTo(machineKey, srcCompressed) } else { respBody, err = encode(resp, &machineKey, h.privateKey) if err != nil { @@ -257,7 +258,7 @@ func (h *Headscale) getMapResponse( } func (h *Headscale) getMapKeepAliveResponse( - machineKey wgkey.Key, + machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, ) ([]byte, error) { mapResponse := tailcfg.MapResponse{ @@ -269,10 +270,7 @@ func (h *Headscale) getMapKeepAliveResponse( src, _ := json.Marshal(mapResponse) encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) - if err != nil { - return nil, err - } + respBody = h.privateKey.SealTo(machineKey, srcCompressed) } else { respBody, err = encode(mapResponse, &machineKey, h.privateKey) if err != nil { @@ -288,13 +286,12 @@ func (h *Headscale) getMapKeepAliveResponse( func (h *Headscale) handleMachineLogOut( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, machine Machine, ) { resp := tailcfg.RegisterResponse{} log.Info(). - Str("handler", "Registration"). Str("machine", machine.Name). Msg("Client requested logout") @@ -306,7 +303,7 @@ func (h *Headscale) handleMachineLogOut( respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") @@ -318,14 +315,13 @@ func (h *Headscale) handleMachineLogOut( func (h *Headscale) handleMachineValidRegistration( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, machine Machine, ) { resp := tailcfg.RegisterResponse{} // The machine registration is valid, respond with redirect to /map log.Debug(). - Str("handler", "Registration"). Str("machine", machine.Name). Msg("Client is registered and we have the current NodeKey. All clear to /map") @@ -337,7 +333,7 @@ func (h *Headscale) handleMachineValidRegistration( respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot encode message") machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). @@ -353,7 +349,7 @@ func (h *Headscale) handleMachineValidRegistration( func (h *Headscale) handleMachineExpired( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { @@ -361,7 +357,6 @@ func (h *Headscale) handleMachineExpired( // The client has registered before, but has expired log.Debug(). - Str("handler", "Registration"). Str("machine", machine.Name). Msg("Machine registration has expired. Sending a authurl to register") @@ -373,16 +368,16 @@ func (h *Headscale) handleMachineExpired( if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) } respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot encode message") machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). @@ -398,17 +393,16 @@ func (h *Headscale) handleMachineExpired( func (h *Headscale) handleMachineRefreshKey( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { resp := tailcfg.RegisterResponse{} log.Debug(). - Str("handler", "Registration"). Str("machine", machine.Name). Msg("We have the OldNodeKey in the database. This is a key refresh") - machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString() + machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) h.db.Save(&machine) resp.AuthURL = "" @@ -416,7 +410,7 @@ func (h *Headscale) handleMachineRefreshKey( respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "Extremely sad!") @@ -428,7 +422,7 @@ func (h *Headscale) handleMachineRefreshKey( func (h *Headscale) handleMachineRegistrationNew( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { @@ -436,18 +430,17 @@ func (h *Headscale) handleMachineRegistrationNew( // The machine registration is new, redirect the client to the registration URL log.Debug(). - Str("handler", "Registration"). Str("machine", machine.Name). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.HexString(), + machineKey.String(), ) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) } if !registerRequest.Expiry.IsZero() { @@ -457,19 +450,21 @@ func (h *Headscale) handleMachineRegistrationNew( Time("expiry", registerRequest.Expiry). Msg("Non-zero expiry time requested, adding to cache") h.requestedExpiryCache.Set( - machineKey.HexString(), + machineKey.String(), registerRequest.Expiry, requestedExpiryCacheExpiration, ) } - machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString() // save the NodeKey + machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + + // save the NodeKey h.db.Save(&machine) respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). - Str("handler", "Registration"). + Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") @@ -481,7 +476,7 @@ func (h *Headscale) handleMachineRegistrationNew( func (h *Headscale) handleAuthKey( ctx *gin.Context, - machineKey wgkey.Key, + machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { @@ -493,6 +488,7 @@ func (h *Headscale) handleAuthKey( pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). + Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). @@ -501,6 +497,7 @@ func (h *Headscale) handleAuthKey( respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). + Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). @@ -513,6 +510,7 @@ func (h *Headscale) handleAuthKey( } ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). + Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Failed authentication via AuthKey") @@ -537,6 +535,7 @@ func (h *Headscale) handleAuthKey( ip, err := h.getAvailableIP() if err != nil { log.Error(). + Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Failed to find an available IP") @@ -555,9 +554,9 @@ func (h *Headscale) handleAuthKey( machine.AuthKeyID = uint(pak.ID) machine.IPAddress = ip.String() machine.NamespaceID = pak.NamespaceID - machine.NodeKey = wgkey.Key(registerRequest.NodeKey). - HexString() - // we update it just in case + + machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + // we update it just in case machine.Registered = true machine.RegisterMethod = RegisterMethodAuthKey h.db.Save(&machine) @@ -571,6 +570,7 @@ func (h *Headscale) handleAuthKey( respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). + Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index f5117d9..2c7fa9b 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -11,9 +11,9 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/pterm/pterm" "github.com/spf13/cobra" + "go4.org/mem" "google.golang.org/grpc/status" - "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) func init() { @@ -486,11 +486,10 @@ func nodesToPtables( expiry = machine.Expiry.AsTime() } - nKey, err := wgkey.ParseHex(machine.NodeKey) + nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) if err != nil { return nil, err } - nodeKey := tailcfg.NodeKey(nKey) var online string if lastSeen.After( diff --git a/machine.go b/machine.go index 306b3a4..26c59c6 100644 --- a/machine.go +++ b/machine.go @@ -12,12 +12,12 @@ import ( "github.com/fatih/set" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" + "go4.org/mem" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/datatypes" - "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) const ( @@ -260,9 +260,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { } // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. -func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) { +func (h *Headscale) GetMachineByMachineKey( + machineKey key.MachinePublic, +) (*Machine, error) { m := Machine{} - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil { + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } @@ -437,25 +439,31 @@ func (machine Machine) toNode( dnsConfig *tailcfg.DNSConfig, includeRoutes bool, ) (*tailcfg.Node, error) { - nodeKey, err := wgkey.ParseHex(machine.NodeKey) + nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey)) if err != nil { - return nil, err + log.Trace(). + Caller(). + Str("node_key", machine.NodeKey). + Msgf("Failed to parse node public key from hex") + + return nil, fmt.Errorf("failed to parse node public key: %w", err) } - machineKey, err := wgkey.ParseHex(machine.MachineKey) + machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse machine public key: %w", err) } - var discoKey tailcfg.DiscoKey + var discoKey key.DiscoPublic if machine.DiscoKey != "" { - dKey, err := wgkey.ParseHex(machine.DiscoKey) + dKey := key.DiscoPublic{} + err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse disco public key: %w", err) } - discoKey = tailcfg.DiscoKey(dKey) + discoKey = key.DiscoPublic(dKey) } else { - discoKey = tailcfg.DiscoKey{} + discoKey = key.DiscoPublic{} } addrs := []netaddr.IPPrefix{} @@ -555,9 +563,9 @@ func (machine Machine) toNode( ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, User: tailcfg.UserID(machine.NamespaceID), - Key: tailcfg.NodeKey(nodeKey), + Key: nodeKey, KeyExpiry: keyExpiry, - Machine: tailcfg.MachineKey(machineKey), + Machine: machineKey, DiscoKey: discoKey, Addresses: addrs, AllowedIPs: allowedIPs, @@ -618,31 +626,35 @@ func (machine *Machine) toProto() *v1.Machine { // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. func (h *Headscale) RegisterMachine( - key string, + machineKeyStr string, namespaceName string, ) (*Machine, error) { namespace, err := h.GetNamespace(namespaceName) if err != nil { return nil, err } - machineKey, err := wgkey.ParseHex(key) + + machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) if err != nil { return nil, err } - machine := Machine{} - if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, errMachineNotFound + log.Trace(). + Caller(). + Str("machine_key_str", machineKeyStr). + Str("machine_key", machineKey.String()). + Msg("Registering machine") + + machine, err := h.GetMachineByMachineKey(machineKey) + if err != nil { + return nil, err } // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set // This means that if a user is to slow with register a machine, it will possibly not // have the correct expiry. requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found { + if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { log.Trace(). Caller(). Str("machine", machine.Name). @@ -658,9 +670,9 @@ func (h *Headscale) RegisterMachine( Str("machine", machine.Name). Msg("machine already registered, reauthenticating") - h.RefreshMachine(&machine, requestedTime) + h.RefreshMachine(machine, requestedTime) - return &machine, nil + return machine, nil } log.Trace(). @@ -709,7 +721,7 @@ func (h *Headscale) RegisterMachine( Str("ip", ip.String()). Msg("Machine registered with the database") - return &machine, nil + return machine, nil } func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { diff --git a/oidc.go b/oidc.go index 9b0a308..02666d8 100644 --- a/oidc.go +++ b/oidc.go @@ -15,8 +15,10 @@ import ( "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" + "go4.org/mem" "golang.org/x/oauth2" "gorm.io/gorm" + "tailscale.com/types/key" ) const ( @@ -187,7 +189,17 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - machineKey, machineKeyOK := machineKeyIf.(string) + + machineKeyStr, machineKeyOK := machineKeyIf.(string) + + machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) + if err != nil { + log.Error(). + Msg("could not parse machine public key") + ctx.String(http.StatusBadRequest, "could not parse public key") + + return + } if !machineKeyOK { log.Error().Msg("could not get machine key from cache") @@ -201,7 +213,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found { + if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { if reqTime, ok := requestedTimeIf.(time.Time); ok { requestedTime = reqTime } diff --git a/oidc_test.go b/oidc_test.go index 21a4357..db581b9 100644 --- a/oidc_test.go +++ b/oidc_test.go @@ -9,7 +9,7 @@ import ( "golang.org/x/oauth2" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) func TestHeadscale_getNamespaceFromEmail(t *testing.T) { @@ -19,8 +19,8 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) { dbString string dbType string dbDebug bool - publicKey *wgkey.Key - privateKey *wgkey.Private + publicKey *key.MachinePublic + privateKey *key.MachinePrivate aclPolicy *ACLPolicy aclRules []tailcfg.FilterRule lastStateChange sync.Map diff --git a/poll.go b/poll.go index 9cf14e7..b993fa7 100644 --- a/poll.go +++ b/poll.go @@ -9,10 +9,11 @@ import ( "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "go4.org/mem" "gorm.io/datatypes" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) const ( @@ -36,7 +37,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Msg("PollNetMapHandler called") body, _ := io.ReadAll(ctx.Request.Body) mKeyStr := ctx.Param("id") - mKey, err := wgkey.ParseHex(mKeyStr) + mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr)) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -58,19 +59,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { return } - machine, err := h.GetMachineByMachineKey(mKey.HexString()) + machine, err := h.GetMachineByMachineKey(mKey) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). Str("handler", "PollNetMap"). - Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) + Msgf("Ignoring request, cannot find machine with key %s", mKey.String()) ctx.String(http.StatusUnauthorized, "") return } log.Error(). Str("handler", "PollNetMap"). - Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString()) + Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String()) ctx.String(http.StatusInternalServerError, "") } log.Trace(). @@ -82,7 +83,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { hostinfo, _ := json.Marshal(req.Hostinfo) machine.Name = req.Hostinfo.Hostname machine.HostInfo = datatypes.JSON(hostinfo) - machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString() + machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey) now := time.Now().UTC() // From Tailscale client: @@ -225,7 +226,7 @@ func (h *Headscale) PollNetMapStream( ctx *gin.Context, machine *Machine, mapRequest tailcfg.MapRequest, - machineKey wgkey.Key, + machineKey key.MachinePublic, pollDataChan chan []byte, keepAliveChan chan []byte, updateChan chan struct{}, @@ -491,7 +492,7 @@ func (h *Headscale) scheduledPollWorker( cancelChan <-chan struct{}, updateChan chan<- struct{}, keepAliveChan chan<- []byte, - machineKey wgkey.Key, + machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, machine *Machine, ) { diff --git a/utils.go b/utils.go index 9f7849e..d0fa1b5 100644 --- a/utils.go +++ b/utils.go @@ -7,50 +7,95 @@ package headscale import ( "context" - "crypto/rand" "encoding/json" "fmt" - "io" "net" "strings" - "golang.org/x/crypto/nacl/box" + "github.com/rs/zerolog/log" "inet.af/netaddr" "tailscale.com/tailcfg" - "tailscale.com/types/wgkey" + "tailscale.com/types/key" ) const ( errCannotDecryptReponse = Error("cannot decrypt response") errResponseMissingNonce = Error("response missing nonce") errCouldNotAllocateIP = Error("could not find any suitable IP") + + // These constants are copied from the upstream tailscale.com/types/key + // library, because they are not exported. + // https://github.com/tailscale/tailscale/tree/main/types/key + + // nodePrivateHexPrefix is the prefix used to identify a + // hex-encoded node private key. + // + // This prefix name is a little unfortunate, in that it comes from + // WireGuard's own key types, and we've used it for both key types + // we persist to disk (machine and node keys). But we're stuck + // with it for now, barring another round of tricky migration. + nodePrivateHexPrefix = "privkey:" + + // nodePublicHexPrefix is the prefix used to identify a + // hex-encoded node public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + nodePublicHexPrefix = "nodekey:" + + // machinePrivateHexPrefix is the prefix used to identify a + // hex-encoded machine private key. + // + // This prefix name is a little unfortunate, in that it comes from + // WireGuard's own key types. Unfortunately we're stuck with it for + // machine keys, because we serialize them to disk with this prefix. + machinePrivateHexPrefix = "privkey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" + + // discoPublicHexPrefix is the prefix used to identify a + // hex-encoded disco public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + discoPublicHexPrefix = "discokey:" ) +func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { + return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) +} + +func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { + return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) +} + +func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { + return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) +} + // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors type Error string func (e Error) Error() string { return string(e) } func decode( - msg []byte, - v interface{}, - pubKey *wgkey.Key, - privKey *wgkey.Private, -) error { - return decodeMsg(msg, v, pubKey, privKey) -} - -func decodeMsg( msg []byte, output interface{}, - pubKey *wgkey.Key, - privKey *wgkey.Private, + pubKey *key.MachinePublic, + privKey *key.MachinePrivate, ) error { - decrypted, err := decryptMsg(msg, pubKey, privKey) - if err != nil { - return err + log.Trace().Int("length", len(msg)).Msg("Trying to decrypt") + + decrypted, ok := privKey.OpenFrom(*pubKey, msg) + if !ok { + return errCannotDecryptReponse } - // fmt.Println(string(decrypted)) + if err := json.Unmarshal(decrypted, output); err != nil { return err } @@ -58,45 +103,17 @@ func decodeMsg( return nil } -func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { - var nonce [24]byte - if len(msg) < len(nonce)+1 { - return nil, errResponseMissingNonce - } - copy(nonce[:], msg) - msg = msg[len(nonce):] - - pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) - decrypted, ok := box.Open(nil, msg, &nonce, pub, pri) - if !ok { - return nil, errCannotDecryptReponse - } - - return decrypted, nil -} - -func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { +func encode( + v interface{}, + pubKey *key.MachinePublic, + privKey *key.MachinePrivate, +) ([]byte, error) { b, err := json.Marshal(v) if err != nil { return nil, err } - return encodeMsg(b, pubKey, privKey) -} - -func encodeMsg( - payload []byte, - pubKey *wgkey.Key, - privKey *wgkey.Private, -) ([]byte, error) { - var nonce [24]byte - if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { - panic(err) - } - pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) - msg := box.Seal(nonce[:], payload, &nonce, pub, pri) - - return msg, nil + return privKey.SealTo(*pubKey, b), nil } func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {