Factor wgkey to types/key
This commit converts all the uses of wgkey to the new key interfaces. It now has specific machine, node and discovery keys and we now should use them correctly. Please note the new logic which strips a key prefix (in utils.go) that is now standard inside tailscale. In theory we could put it in the database, but to preserve backwards compatibility and not spend a lot of resources on accounting for both, we just strip them.
This commit is contained in:
parent
07418140a2
commit
cfd53bc4aa
7 changed files with 184 additions and 143 deletions
92
api.go
92
api.go
|
@ -13,9 +13,10 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"go4.org/mem"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -34,7 +35,7 @@ func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
||||||
ctx.Data(
|
ctx.Data(
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
"text/plain; charset=utf-8",
|
"text/plain; charset=utf-8",
|
||||||
[]byte(h.publicKey.HexString()),
|
[]byte(MachinePublicKeyStripPrefix(*h.publicKey)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,10 +74,10 @@ func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||||
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
body, _ := io.ReadAll(ctx.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
machineKeyStr := ctx.Param("id")
|
machineKeyStr := ctx.Param("id")
|
||||||
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot parse machine key")
|
Msg("Cannot parse machine key")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
|
@ -88,7 +89,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
err = decode(body, &req, &machineKey, h.privateKey)
|
err = decode(body, &req, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot decode message")
|
Msg("Cannot decode message")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
|
@ -98,17 +99,17 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||||
newMachine := Machine{
|
newMachine := Machine{
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
MachineKey: machineKey.HexString(),
|
MachineKey: MachinePublicKeyStripPrefix(machineKey),
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
}
|
}
|
||||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not create row")
|
Msg("Could not create row")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||||
|
@ -125,7 +126,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
// - Trying to log out (sending a expiry in the past)
|
// - Trying to log out (sending a expiry in the past)
|
||||||
// - A valid, registered machine, looking for the node map
|
// - A valid, registered machine, looking for the node map
|
||||||
// - Expired machine wanting to reauthenticate
|
// - Expired machine wanting to reauthenticate
|
||||||
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
if machine.NodeKey == req.NodeKey.String() {
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||||
|
@ -144,7 +145,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
if machine.NodeKey == req.OldNodeKey.String() &&
|
||||||
!machine.isExpired() {
|
!machine.isExpired() {
|
||||||
h.handleMachineRefreshKey(ctx, machineKey, req, *machine)
|
h.handleMachineRefreshKey(ctx, machineKey, req, *machine)
|
||||||
|
|
||||||
|
@ -168,7 +169,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapResponse(
|
func (h *Headscale) getMapResponse(
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
req tailcfg.MapRequest,
|
req tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *Machine,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
|
@ -179,6 +180,7 @@ func (h *Headscale) getMapResponse(
|
||||||
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot convert to node")
|
Msg("Cannot convert to node")
|
||||||
|
@ -189,6 +191,7 @@ func (h *Headscale) getMapResponse(
|
||||||
peers, err := h.getValidPeers(machine)
|
peers, err := h.getValidPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot fetch peers")
|
Msg("Cannot fetch peers")
|
||||||
|
@ -201,6 +204,7 @@ func (h *Headscale) getMapResponse(
|
||||||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to convert peers to Tailscale nodes")
|
Msg("Failed to convert peers to Tailscale nodes")
|
||||||
|
@ -238,10 +242,7 @@ func (h *Headscale) getMapResponse(
|
||||||
|
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(resp, &machineKey, h.privateKey)
|
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -257,7 +258,7 @@ func (h *Headscale) getMapResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapKeepAliveResponse(
|
func (h *Headscale) getMapKeepAliveResponse(
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
mapResponse := tailcfg.MapResponse{
|
mapResponse := tailcfg.MapResponse{
|
||||||
|
@ -269,10 +270,7 @@ func (h *Headscale) getMapKeepAliveResponse(
|
||||||
src, _ := json.Marshal(mapResponse)
|
src, _ := json.Marshal(mapResponse)
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -288,13 +286,12 @@ func (h *Headscale) getMapKeepAliveResponse(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineLogOut(
|
func (h *Headscale) handleMachineLogOut(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "Registration").
|
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client requested logout")
|
Msg("Client requested logout")
|
||||||
|
|
||||||
|
@ -306,7 +303,7 @@ func (h *Headscale) handleMachineLogOut(
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
ctx.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
@ -318,14 +315,13 @@ func (h *Headscale) handleMachineLogOut(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineValidRegistration(
|
func (h *Headscale) handleMachineValidRegistration(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
// The machine registration is valid, respond with redirect to /map
|
// The machine registration is valid, respond with redirect to /map
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||||
|
|
||||||
|
@ -337,7 +333,7 @@ func (h *Headscale) handleMachineValidRegistration(
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||||
|
@ -353,7 +349,7 @@ func (h *Headscale) handleMachineValidRegistration(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineExpired(
|
func (h *Headscale) handleMachineExpired(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
|
@ -361,7 +357,6 @@ func (h *Headscale) handleMachineExpired(
|
||||||
|
|
||||||
// The client has registered before, but has expired
|
// The client has registered before, but has expired
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Machine registration has expired. Sending a authurl to register")
|
Msg("Machine registration has expired. Sending a authurl to register")
|
||||||
|
|
||||||
|
@ -373,16 +368,16 @@ func (h *Headscale) handleMachineExpired(
|
||||||
|
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
|
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
|
||||||
|
@ -398,17 +393,16 @@ func (h *Headscale) handleMachineExpired(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineRefreshKey(
|
func (h *Headscale) handleMachineRefreshKey(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString()
|
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
|
@ -416,7 +410,7 @@ func (h *Headscale) handleMachineRefreshKey(
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||||
|
@ -428,7 +422,7 @@ func (h *Headscale) handleMachineRefreshKey(
|
||||||
|
|
||||||
func (h *Headscale) handleMachineRegistrationNew(
|
func (h *Headscale) handleMachineRegistrationNew(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
|
@ -436,18 +430,17 @@ func (h *Headscale) handleMachineRegistrationNew(
|
||||||
|
|
||||||
// The machine registration is new, redirect the client to the registration URL
|
// The machine registration is new, redirect the client to the registration URL
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf(
|
resp.AuthURL = fmt.Sprintf(
|
||||||
"%s/oidc/register/%s",
|
"%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||||
machineKey.HexString(),
|
machineKey.String(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !registerRequest.Expiry.IsZero() {
|
if !registerRequest.Expiry.IsZero() {
|
||||||
|
@ -457,19 +450,21 @@ func (h *Headscale) handleMachineRegistrationNew(
|
||||||
Time("expiry", registerRequest.Expiry).
|
Time("expiry", registerRequest.Expiry).
|
||||||
Msg("Non-zero expiry time requested, adding to cache")
|
Msg("Non-zero expiry time requested, adding to cache")
|
||||||
h.requestedExpiryCache.Set(
|
h.requestedExpiryCache.Set(
|
||||||
machineKey.HexString(),
|
machineKey.String(),
|
||||||
registerRequest.Expiry,
|
registerRequest.Expiry,
|
||||||
requestedExpiryCacheExpiration,
|
requestedExpiryCacheExpiration,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).HexString() // save the NodeKey
|
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
|
|
||||||
|
// save the NodeKey
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
ctx.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
@ -481,7 +476,7 @@ func (h *Headscale) handleMachineRegistrationNew(
|
||||||
|
|
||||||
func (h *Headscale) handleAuthKey(
|
func (h *Headscale) handleAuthKey(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machine Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
|
@ -493,6 +488,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -501,6 +497,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -513,6 +510,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
}
|
}
|
||||||
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
|
@ -537,6 +535,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed to find an available IP")
|
Msg("Failed to find an available IP")
|
||||||
|
@ -555,9 +554,9 @@ func (h *Headscale) handleAuthKey(
|
||||||
machine.AuthKeyID = uint(pak.ID)
|
machine.AuthKeyID = uint(pak.ID)
|
||||||
machine.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
machine.NamespaceID = pak.NamespaceID
|
machine.NamespaceID = pak.NamespaceID
|
||||||
machine.NodeKey = wgkey.Key(registerRequest.NodeKey).
|
|
||||||
HexString()
|
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||||
// we update it just in case
|
// we update it just in case
|
||||||
machine.Registered = true
|
machine.Registered = true
|
||||||
machine.RegisterMethod = RegisterMethodAuthKey
|
machine.RegisterMethod = RegisterMethodAuthKey
|
||||||
h.db.Save(&machine)
|
h.db.Save(&machine)
|
||||||
|
@ -571,6 +570,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
|
|
|
@ -11,9 +11,9 @@ import (
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"go4.org/mem"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/wgkey"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -486,11 +486,10 @@ func nodesToPtables(
|
||||||
expiry = machine.Expiry.AsTime()
|
expiry = machine.Expiry.AsTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
nKey, err := wgkey.ParseHex(machine.NodeKey)
|
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
nodeKey := tailcfg.NodeKey(nKey)
|
|
||||||
|
|
||||||
var online string
|
var online string
|
||||||
if lastSeen.After(
|
if lastSeen.After(
|
||||||
|
|
66
machine.go
66
machine.go
|
@ -12,12 +12,12 @@ import (
|
||||||
"github.com/fatih/set"
|
"github.com/fatih/set"
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"go4.org/mem"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"gorm.io/gorm"
|
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -260,9 +260,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
|
func (h *Headscale) GetMachineByMachineKey(
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,25 +439,31 @@ func (machine Machine) toNode(
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
includeRoutes bool,
|
includeRoutes bool,
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
|
nodeKey, err := key.ParseNodePublicUntyped(mem.S(machine.NodeKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("node_key", machine.NodeKey).
|
||||||
|
Msgf("Failed to parse node public key from hex")
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to parse node public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
machineKey, err := wgkey.ParseHex(machine.MachineKey)
|
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machine.MachineKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var discoKey tailcfg.DiscoKey
|
var discoKey key.DiscoPublic
|
||||||
if machine.DiscoKey != "" {
|
if machine.DiscoKey != "" {
|
||||||
dKey, err := wgkey.ParseHex(machine.DiscoKey)
|
dKey := key.DiscoPublic{}
|
||||||
|
err := dKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||||
}
|
}
|
||||||
discoKey = tailcfg.DiscoKey(dKey)
|
discoKey = key.DiscoPublic(dKey)
|
||||||
} else {
|
} else {
|
||||||
discoKey = tailcfg.DiscoKey{}
|
discoKey = key.DiscoPublic{}
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs := []netaddr.IPPrefix{}
|
addrs := []netaddr.IPPrefix{}
|
||||||
|
@ -555,9 +563,9 @@ func (machine Machine) toNode(
|
||||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
User: tailcfg.UserID(machine.NamespaceID),
|
User: tailcfg.UserID(machine.NamespaceID),
|
||||||
Key: tailcfg.NodeKey(nodeKey),
|
Key: nodeKey,
|
||||||
KeyExpiry: keyExpiry,
|
KeyExpiry: keyExpiry,
|
||||||
Machine: tailcfg.MachineKey(machineKey),
|
Machine: machineKey,
|
||||||
DiscoKey: discoKey,
|
DiscoKey: discoKey,
|
||||||
Addresses: addrs,
|
Addresses: addrs,
|
||||||
AllowedIPs: allowedIPs,
|
AllowedIPs: allowedIPs,
|
||||||
|
@ -618,31 +626,35 @@ func (machine *Machine) toProto() *v1.Machine {
|
||||||
|
|
||||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||||
func (h *Headscale) RegisterMachine(
|
func (h *Headscale) RegisterMachine(
|
||||||
key string,
|
machineKeyStr string,
|
||||||
namespaceName string,
|
namespaceName string,
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
namespace, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
machineKey, err := wgkey.ParseHex(key)
|
|
||||||
|
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machine := Machine{}
|
log.Trace().
|
||||||
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
|
Caller().
|
||||||
result.Error,
|
Str("machine_key_str", machineKeyStr).
|
||||||
gorm.ErrRecordNotFound,
|
Str("machine_key", machineKey.String()).
|
||||||
) {
|
Msg("Registering machine")
|
||||||
return nil, errMachineNotFound
|
|
||||||
|
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
||||||
// This means that if a user is to slow with register a machine, it will possibly not
|
// This means that if a user is to slow with register a machine, it will possibly not
|
||||||
// have the correct expiry.
|
// have the correct expiry.
|
||||||
requestedTime := time.Time{}
|
requestedTime := time.Time{}
|
||||||
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found {
|
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
|
@ -658,9 +670,9 @@ func (h *Headscale) RegisterMachine(
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("machine already registered, reauthenticating")
|
Msg("machine already registered, reauthenticating")
|
||||||
|
|
||||||
h.RefreshMachine(&machine, requestedTime)
|
h.RefreshMachine(machine, requestedTime)
|
||||||
|
|
||||||
return &machine, nil
|
return machine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
|
@ -709,7 +721,7 @@ func (h *Headscale) RegisterMachine(
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Machine registered with the database")
|
Msg("Machine registered with the database")
|
||||||
|
|
||||||
return &machine, nil
|
return machine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
|
|
16
oidc.go
16
oidc.go
|
@ -15,8 +15,10 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"go4.org/mem"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -187,7 +189,17 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineKey, machineKeyOK := machineKeyIf.(string)
|
|
||||||
|
machineKeyStr, machineKeyOK := machineKeyIf.(string)
|
||||||
|
|
||||||
|
machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Msg("could not parse machine public key")
|
||||||
|
ctx.String(http.StatusBadRequest, "could not parse public key")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !machineKeyOK {
|
if !machineKeyOK {
|
||||||
log.Error().Msg("could not get machine key from cache")
|
log.Error().Msg("could not get machine key from cache")
|
||||||
|
@ -201,7 +213,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
|
|
||||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
||||||
requestedTime := time.Time{}
|
requestedTime := time.Time{}
|
||||||
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found {
|
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
|
||||||
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
||||||
requestedTime = reqTime
|
requestedTime = reqTime
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||||
|
@ -19,8 +19,8 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||||
dbString string
|
dbString string
|
||||||
dbType string
|
dbType string
|
||||||
dbDebug bool
|
dbDebug bool
|
||||||
publicKey *wgkey.Key
|
publicKey *key.MachinePublic
|
||||||
privateKey *wgkey.Private
|
privateKey *key.MachinePrivate
|
||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
aclRules []tailcfg.FilterRule
|
aclRules []tailcfg.FilterRule
|
||||||
lastStateChange sync.Map
|
lastStateChange sync.Map
|
||||||
|
|
17
poll.go
17
poll.go
|
@ -9,10 +9,11 @@ import (
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"go4.org/mem"
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -36,7 +37,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
Msg("PollNetMapHandler called")
|
Msg("PollNetMapHandler called")
|
||||||
body, _ := io.ReadAll(ctx.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
mKeyStr := ctx.Param("id")
|
mKeyStr := ctx.Param("id")
|
||||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
mKey, err := key.ParseMachinePublicUntyped(mem.S(mKeyStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
@ -58,19 +59,19 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machine, err := h.GetMachineByMachineKey(mKey.HexString())
|
machine, err := h.GetMachineByMachineKey(mKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
Msgf("Ignoring request, cannot find machine with key %s", mKey.String())
|
||||||
ctx.String(http.StatusUnauthorized, "")
|
ctx.String(http.StatusUnauthorized, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
|
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.String())
|
||||||
ctx.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
|
@ -82,7 +83,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||||
machine.Name = req.Hostinfo.Hostname
|
machine.Name = req.Hostinfo.Hostname
|
||||||
machine.HostInfo = datatypes.JSON(hostinfo)
|
machine.HostInfo = datatypes.JSON(hostinfo)
|
||||||
machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
// From Tailscale client:
|
// From Tailscale client:
|
||||||
|
@ -225,7 +226,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
ctx *gin.Context,
|
ctx *gin.Context,
|
||||||
machine *Machine,
|
machine *Machine,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
pollDataChan chan []byte,
|
pollDataChan chan []byte,
|
||||||
keepAliveChan chan []byte,
|
keepAliveChan chan []byte,
|
||||||
updateChan chan struct{},
|
updateChan chan struct{},
|
||||||
|
@ -491,7 +492,7 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
cancelChan <-chan struct{},
|
cancelChan <-chan struct{},
|
||||||
updateChan chan<- struct{},
|
updateChan chan<- struct{},
|
||||||
keepAliveChan chan<- []byte,
|
keepAliveChan chan<- []byte,
|
||||||
machineKey wgkey.Key,
|
machineKey key.MachinePublic,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
machine *Machine,
|
machine *Machine,
|
||||||
) {
|
) {
|
||||||
|
|
123
utils.go
123
utils.go
|
@ -7,50 +7,95 @@ package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/box"
|
"github.com/rs/zerolog/log"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errCannotDecryptReponse = Error("cannot decrypt response")
|
errCannotDecryptReponse = Error("cannot decrypt response")
|
||||||
errResponseMissingNonce = Error("response missing nonce")
|
errResponseMissingNonce = Error("response missing nonce")
|
||||||
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
||||||
|
|
||||||
|
// These constants are copied from the upstream tailscale.com/types/key
|
||||||
|
// library, because they are not exported.
|
||||||
|
// https://github.com/tailscale/tailscale/tree/main/types/key
|
||||||
|
|
||||||
|
// nodePrivateHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded node private key.
|
||||||
|
//
|
||||||
|
// This prefix name is a little unfortunate, in that it comes from
|
||||||
|
// WireGuard's own key types, and we've used it for both key types
|
||||||
|
// we persist to disk (machine and node keys). But we're stuck
|
||||||
|
// with it for now, barring another round of tricky migration.
|
||||||
|
nodePrivateHexPrefix = "privkey:"
|
||||||
|
|
||||||
|
// nodePublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded node public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
nodePublicHexPrefix = "nodekey:"
|
||||||
|
|
||||||
|
// machinePrivateHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded machine private key.
|
||||||
|
//
|
||||||
|
// This prefix name is a little unfortunate, in that it comes from
|
||||||
|
// WireGuard's own key types. Unfortunately we're stuck with it for
|
||||||
|
// machine keys, because we serialize them to disk with this prefix.
|
||||||
|
machinePrivateHexPrefix = "privkey:"
|
||||||
|
|
||||||
|
// machinePublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded machine public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
machinePublicHexPrefix = "mkey:"
|
||||||
|
|
||||||
|
// discoPublicHexPrefix is the prefix used to identify a
|
||||||
|
// hex-encoded disco public key.
|
||||||
|
//
|
||||||
|
// This prefix is used in the control protocol, so cannot be
|
||||||
|
// changed.
|
||||||
|
discoPublicHexPrefix = "discokey:"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
|
||||||
|
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
|
||||||
|
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
|
||||||
|
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
||||||
type Error string
|
type Error string
|
||||||
|
|
||||||
func (e Error) Error() string { return string(e) }
|
func (e Error) Error() string { return string(e) }
|
||||||
|
|
||||||
func decode(
|
func decode(
|
||||||
msg []byte,
|
|
||||||
v interface{},
|
|
||||||
pubKey *wgkey.Key,
|
|
||||||
privKey *wgkey.Private,
|
|
||||||
) error {
|
|
||||||
return decodeMsg(msg, v, pubKey, privKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeMsg(
|
|
||||||
msg []byte,
|
msg []byte,
|
||||||
output interface{},
|
output interface{},
|
||||||
pubKey *wgkey.Key,
|
pubKey *key.MachinePublic,
|
||||||
privKey *wgkey.Private,
|
privKey *key.MachinePrivate,
|
||||||
) error {
|
) error {
|
||||||
decrypted, err := decryptMsg(msg, pubKey, privKey)
|
log.Trace().Int("length", len(msg)).Msg("Trying to decrypt")
|
||||||
if err != nil {
|
|
||||||
return err
|
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
|
||||||
|
if !ok {
|
||||||
|
return errCannotDecryptReponse
|
||||||
}
|
}
|
||||||
// fmt.Println(string(decrypted))
|
|
||||||
if err := json.Unmarshal(decrypted, output); err != nil {
|
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -58,45 +103,17 @@ func decodeMsg(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
func encode(
|
||||||
var nonce [24]byte
|
v interface{},
|
||||||
if len(msg) < len(nonce)+1 {
|
pubKey *key.MachinePublic,
|
||||||
return nil, errResponseMissingNonce
|
privKey *key.MachinePrivate,
|
||||||
}
|
) ([]byte, error) {
|
||||||
copy(nonce[:], msg)
|
|
||||||
msg = msg[len(nonce):]
|
|
||||||
|
|
||||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
|
||||||
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
|
|
||||||
if !ok {
|
|
||||||
return nil, errCannotDecryptReponse
|
|
||||||
}
|
|
||||||
|
|
||||||
return decrypted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
|
||||||
b, err := json.Marshal(v)
|
b, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return encodeMsg(b, pubKey, privKey)
|
return privKey.SealTo(*pubKey, b), nil
|
||||||
}
|
|
||||||
|
|
||||||
func encodeMsg(
|
|
||||||
payload []byte,
|
|
||||||
pubKey *wgkey.Key,
|
|
||||||
privKey *wgkey.Private,
|
|
||||||
) ([]byte, error) {
|
|
||||||
var nonce [24]byte
|
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
|
||||||
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
|
|
||||||
|
|
||||||
return msg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
||||||
|
|
Loading…
Reference in a new issue