From 74e6c1479e64ea13e49fbb4ca87f668dd14068ab Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sun, 10 Oct 2021 17:22:42 +0800 Subject: [PATCH] updates from code review --- api.go | 71 +++++++++++++------------------------- app.go | 4 +-- cli.go | 3 ++ cmd/headscale/cli/utils.go | 18 +++++----- go.mod | 6 ++-- machine.go | 30 ++++++++++++++-- oidc.go | 43 ++++++++++------------- 7 files changed, 88 insertions(+), 87 deletions(-) diff --git a/api.go b/api.go index a70df5b..bda9d9b 100644 --- a/api.go +++ b/api.go @@ -65,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot parse machine key") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Sad!") return } @@ -76,34 +76,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot decode message") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Very sad!") return } now := time.Now().UTC() - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + m, err := h.GetMachineByMachineKey(mKey.HexString()) + if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") - m = Machine{ - Expiry: &time.Time{}, - MachineKey: mKey.HexString(), - Name: req.Hostinfo.Hostname, - NodeKey: wgkey.Key(req.NodeKey).HexString(), - LastSuccessfulUpdate: &now, + newMachine := Machine{ + Expiry: &time.Time{}, + MachineKey: mKey.HexString(), + Name: req.Hostinfo.Hostname, } - if err := h.db.Create(&m).Error; err != nil { + if err := h.db.Create(&newMachine).Error; err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Could not create row") - machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc() return } + m = &newMachine } if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, h.db, mKey, req, m) + h.handleAuthKey(c, h.db, mKey, req, *m) return } @@ -112,13 +111,14 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { + // The client sends an Expiry in the past if the client is requesting a logout if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { - log.Debug(). + log.Info(). Str("handler", "Registration"). Str("machine", m.Name). Msg("Client requested logout") - m.Expiry = &req.Expiry + m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired h.db.Save(&m) resp.AuthURL = "" @@ -138,6 +138,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { } if m.Registered && m.Expiry.UTC().After(now) { + // The machine registration is valid, respond with redirect to /map log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -161,10 +162,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } + // The client has registered before, but has expired log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). - Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register") + Msg("Machine registration has expired. Sending a authurl to register") if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", @@ -174,7 +176,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } - m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + m.RequestedExpiry = &req.Expiry // save the requested expiry time for retrieval later in the authentication flow h.db.Save(&m) respBody, err := encode(resp, &mKey, h.privateKey) @@ -216,34 +218,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // We arrive here after a client is restarted without finalizing the authentication flow or - // when headscale is stopped in the middle of the auth process. - if m.Registered && m.Expiry.UTC().After(now) { - log.Debug(). - Str("handler", "Registration"). - Str("machine", m.Name). - Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") - - m.NodeKey = wgkey.Key(req.NodeKey).HexString() - h.db.Save(&m) - - resp.AuthURL = "" - resp.MachineAuthorized = true - resp.User = *m.Namespace.toUser() - - respBody, err := encode(resp, &mKey, h.privateKey) - if err != nil { - log.Error(). - Str("handler", "Registration"). - Err(err). - Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") - return - } - c.Data(200, "application/json; charset=utf-8", respBody) - return - } - + // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -255,8 +230,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } - m.Expiry = &req.Expiry // save the requested expiry time for retrieval later - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey + m.RequestedExpiry = &req.Expiry // save the requested expiry time for retrieval later in the authentication flow + m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey h.db.Save(&m) respBody, err := encode(resp, &mKey, h.privateKey) @@ -436,6 +411,8 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, m.RegisterMethod = "authKey" db.Save(&m) + h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? + resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() respBody, err := encode(resp, &idKey, h.privateKey) diff --git a/app.go b/app.go index 9e688fe..239998c 100644 --- a/app.go +++ b/app.go @@ -59,8 +59,8 @@ type Config struct { OIDCClientID string OIDCClientSecret string - MaxMachineExpiry time.Duration - DefaultMachineExpiry time.Duration + MaxMachineRegistrationDuration time.Duration + DefaultMachineRegistrationDuration time.Duration } // Headscale represents the base app of the service diff --git a/cli.go b/cli.go index 9c5b66e..8610b33 100644 --- a/cli.go +++ b/cli.go @@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err return nil, errors.New("Machine not found") } + h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered + if m.isAlreadyRegistered() { return nil, errors.New("Machine already registered") } @@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err m.Registered = true m.RegisterMethod = "cli" h.db.Save(&m) + return &m, nil } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 17bc37e..366e959 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -144,14 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - maxMachineExpiry, _ := time.ParseDuration("8h") - if viper.GetDuration("max_machine_expiry") >= time.Second { - maxMachineExpiry = viper.GetDuration("max_machine_expiry") + // maxMachineRegistrationDuration is the maximum time a client can request for a client registration + maxMachineRegistrationDuration, _ := time.ParseDuration("10h") + if viper.GetDuration("max_machine_registration_duration") >= time.Second { + maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") } - defaultMachineExpiry, _ := time.ParseDuration("8h") - if viper.GetDuration("default_machine_expiry") >= time.Second { - defaultMachineExpiry = viper.GetDuration("default_machine_expiry") + // defaultMachineRegistrationDuration is the default time assigned to a client registration if one is not specified by the client + defaultMachineRegistrationDuration, _ := time.ParseDuration("8h") + if viper.GetDuration("default_machine_registration_duration") >= time.Second { + defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") } cfg := headscale.Config{ @@ -188,8 +190,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), - MaxMachineExpiry: maxMachineExpiry, - DefaultMachineExpiry: defaultMachineExpiry, + MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time + DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration } h, err := headscale.NewHeadscale(cfg) diff --git a/go.mod b/go.mod index 7e137e1..5a116bb 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 - github.com/fatih/set v0.2.1 // indirect + github.com/fatih/set v0.2.1 github.com/gin-gonic/gin v1.7.4 github.com/gofrs/uuid v4.0.0+incompatible github.com/google/go-github v17.0.0+incompatible // indirect @@ -23,7 +23,7 @@ require ( github.com/opencontainers/runc v1.0.2 // indirect github.com/ory/dockertest/v3 v3.7.0 github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/prometheus/client_golang v1.11.0 // indirect + github.com/prometheus/client_golang v1.11.0 github.com/pterm/pterm v0.12.30 github.com/rs/zerolog v1.25.0 github.com/s12v/go-jwks v0.2.1 @@ -33,7 +33,7 @@ require ( github.com/tailscale/hujson v0.0.0-20210818175511-7360507a6e88 github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect - github.com/zsais/go-gin-prometheus v0.1.0 // indirect + github.com/zsais/go-gin-prometheus v0.1.0 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 diff --git a/machine.go b/machine.go index bd5caf0..6eecbc6 100644 --- a/machine.go +++ b/machine.go @@ -36,6 +36,7 @@ type Machine struct { LastSeen *time.Time LastSuccessfulUpdate *time.Time Expiry *time.Time + RequestedExpiry *time.Time // when a client connects, it may request a specific expiry time, use this field to store it HostInfo datatypes.JSON Endpoints datatypes.JSON @@ -59,8 +60,33 @@ func (m Machine) isAlreadyRegistered() bool { // isExpired returns whether the machine registration has expired func (m Machine) isExpired() bool { return time.Now().UTC().After(*m.Expiry) -} - +} + +// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, +// or the default duration if no Expiry time was requested by the client +func (h *Headscale) updateMachineExpiry(m *Machine) { + + if m.isExpired() { + now := time.Now().UTC() + maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry + defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry + + // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied + if maxExpiry.Before(*m.RequestedExpiry) { + log.Debug().Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) + m.Expiry = &maxExpiry + } else if m.RequestedExpiry.IsZero() { + log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) + m.Expiry = &defaultExpiry + } else { + log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) + m.Expiry = m.RequestedExpiry + } + + h.db.Save(&m) + } +} + func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). Str("func", "getDirectPeers"). diff --git a/oidc.go b/oidc.go index 1220098..01c54b4 100644 --- a/oidc.go +++ b/oidc.go @@ -4,14 +4,12 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "gorm.io/gorm" "net/http" "strings" "time" @@ -103,6 +101,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } + log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken) + rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { c.String(http.StatusBadRequest, "Could not extract ID Token") @@ -117,16 +117,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } + // TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc) //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) //if err != nil { - // c.String(http.StatusBadRequest, "Failed to retrieve userinfo: "+err.Error()) + // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err)) // return //} // Extract custom claims var claims IDTokenClaims if err = idToken.Claims(&claims); err != nil { - c.String(http.StatusBadRequest, "Failed to decode id token claims: "+err.Error()) + c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err)) return } @@ -134,39 +135,44 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { + log.Error().Msg("requested machine state key expired before authorisation completed") c.String(http.StatusBadRequest, "state has expired") return } mKeyStr, mKeyOK := mKeyIf.(string) if !mKeyOK { + log.Error().Msg("could not get machine key from cache") c.String(http.StatusInternalServerError, "could not get machine key from cache") return } // retrieve machine information - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKeyStr); errors.Is(result.Error, gorm.ErrRecordNotFound) { + m, err := h.GetMachineByMachineKey(mKeyStr) + + if err != nil { log.Error().Msg("machine key not found in database") c.String(http.StatusInternalServerError, "could not get machine info from database") return } - //look for a namespace of the users email for now + now := time.Now().UTC() + + // register the machine if it's new if !m.Registered { + nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation log.Debug().Msg("Registering new machine after successful callback") - ns, err := h.GetNamespace(claims.Email) + ns, err := h.GetNamespace(nsName) if err != nil { - ns, err = h.CreateNamespace(claims.Email) + ns, err = h.CreateNamespace(nsName) if err != nil { log.Error().Msgf("could not create new namespace '%s'", claims.Email) c.String(http.StatusInternalServerError, "could not create new namespace") return } - } ip, err := h.getAvailableIP() @@ -179,24 +185,11 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { m.NamespaceID = ns.ID m.Registered = true m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now h.db.Save(&m) } - if m.isExpired() { - maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry) - - // use the maximum expiry if it's sooner than the requested expiry - if maxExpiry.Before(*m.Expiry) { - log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) - m.Expiry = &maxExpiry - h.db.Save(&m) - } else if m.Expiry.IsZero() { - log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) - defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry) - m.Expiry = &defaultExpiry - h.db.Save(&m) - } - } + h.updateMachineExpiry(m) c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`