diff --git a/api.go b/api.go index ddabd93..02f2891 100644 --- a/api.go +++ b/api.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "github.com/rs/zerolog/log" @@ -83,7 +84,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") m = Machine{ - Expiry: &req.Expiry, + Expiry: &time.Time{}, MachineKey: mKey.HexString(), Name: req.Hostinfo.Hostname, NodeKey: wgkey.Key(req.NodeKey).HexString(), @@ -107,7 +108,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { - if m.Registered { + + if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { + log.Debug(). + Str("handler", "Registration"). + Str("machine", m.Name). + Msg("Client requested logout") + + m.Expiry = &req.Expiry + h.db.Save(&m) + + resp.AuthURL = "" + resp.MachineAuthorized = false + resp.User = *m.Namespace.toUser() + respBody, err := encode(resp, &mKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "Registration"). + Err(err). + Msg("Cannot encode message") + c.String(http.StatusInternalServerError, "") + return + } + c.Data(200, "application/json; charset=utf-8", respBody) + return + } + + if m.Registered && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -132,14 +159,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). - Msg("Not registered and not NodeKey rotation. Sending a authurl to register") + Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register") if h.cfg.OIDCIssuer != "" { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } + + m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -153,8 +185,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration - if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { + // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration + if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -179,14 +211,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We arrive here after a client is restarted without finalizing the authentication flow or // when headscale is stopped in the middle of the auth process. - if m.Registered { + if m.Registered && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") + + m.NodeKey = wgkey.Key(req.NodeKey).HexString() + h.db.Save(&m) + resp.AuthURL = "" resp.MachineAuthorized = true resp.User = *m.Namespace.toUser() + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -210,6 +247,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } + + m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 3c4b307..2ad7215 100644 --- a/app.go +++ b/app.go @@ -3,6 +3,9 @@ package headscale import ( "errors" "fmt" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" "net/http" "os" "strings" @@ -49,6 +52,9 @@ type Config struct { OIDCIssuer string OIDCClientID string OIDCClientSecret string + + MaxMachineExpiry time.Duration + DefaultMachineExpiry time.Duration } // Headscale represents the base app of the service @@ -68,6 +74,10 @@ type Headscale struct { clientsUpdateChannelMutex sync.Mutex lastStateChange sync.Map + + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache } // NewHeadscale returns the Headscale app @@ -107,6 +117,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } + if cfg.OIDCIssuer != "" { + err = h.initOIDC() + if err != nil { + return nil, err + } + } + return &h, nil } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 6ccdcde..67017aa 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -144,6 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } + maxMachineExpiry, _ := time.ParseDuration("8h") + if viper.GetDuration("max_machine_expiry") >= time.Second { + maxMachineExpiry = viper.GetDuration("max_machine_expiry") + } + + defaultMachineExpiry, _ := time.ParseDuration("8h") + if viper.GetDuration("default_machine_expiry") >= time.Second { + defaultMachineExpiry = viper.GetDuration("default_machine_expiry") + } + cfg := headscale.Config{ ServerURL: viper.GetString("server_url"), Addr: viper.GetString("listen_addr"), @@ -174,6 +184,9 @@ func getHeadscaleApp() (*headscale.Headscale, error) { OIDCIssuer: viper.GetString("oidc_issuer"), OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), + + MaxMachineExpiry: maxMachineExpiry, + DefaultMachineExpiry: defaultMachineExpiry, } h, err := headscale.NewHeadscale(cfg) diff --git a/oidc.go b/oidc.go index 328731e..1220098 100644 --- a/oidc.go +++ b/oidc.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "gorm.io/gorm" "net/http" + "strings" "time" ) @@ -23,9 +24,33 @@ type IDTokenClaims struct { Username string `json:"preferred_username,omitempty"` } -var oidcProvider *oidc.Provider -var oauth2Config *oauth2.Config -var stateCache *cache.Cache +func (h *Headscale) initOIDC() error { + var err error + // grab oidc config if it hasn't been already + if h.oauth2Config == nil { + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) + + if err != nil { + log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) + return err + } + + h.oauth2Config = &oauth2.Config{ + ClientID: h.cfg.OIDCClientID, + ClientSecret: h.cfg.OIDCClientSecret, + Endpoint: h.oidcProvider.Endpoint(), + RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + } + + // init the state cache if it hasn't been already + if h.oidcStateCache == nil { + h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10) + } + + return nil +} // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param @@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { return } - var err error - - // grab oidc config if it hasn't been already - if oauth2Config == nil { - oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) - - if err != nil { - log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) - c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config") - return - } - - oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDCClientID, - ClientSecret: h.cfg.OIDCClientSecret, - Endpoint: oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - } - - } - b := make([]byte, 16) - _, err = rand.Read(b) + _, err := rand.Read(b) if err != nil { log.Error().Msg("could not read 16 bytes from rand") @@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { stateStr := hex.EncodeToString(b)[:32] - // init the state cache if it hasn't been already - if stateCache == nil { - stateCache = cache.New(time.Minute*5, time.Minute*10) - } - // place the machine key into the state cache, so it can be retrieved later - stateCache.Set(stateStr, mKeyStr, time.Minute*5) + h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5) - authUrl := oauth2Config.AuthCodeURL(stateStr) + authUrl := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authUrl) c.Redirect(http.StatusFound, authUrl) @@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - oauth2Token, err := oauth2Config.Exchange(context.Background(), code) + oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) if err != nil { c.String(http.StatusBadRequest, "Could not exchange code for token") return @@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { } //retrieve machinekey from state cache - mKeyIf, mKeyFound := stateCache.Get(state) + mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { c.String(http.StatusBadRequest, "state has expired") @@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { //look for a namespace of the users email for now if !m.Registered { + log.Debug().Msg("Registering new machine after successful callback") + ns, err := h.GetNamespace(claims.Email) if err != nil { ns, err = h.CreateNamespace(claims.Email) @@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { h.db.Save(&m) } + if m.isExpired() { + maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry) + + // use the maximum expiry if it's sooner than the requested expiry + if maxExpiry.Before(*m.Expiry) { + log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) + m.Expiry = &maxExpiry + h.db.Save(&m) + } else if m.Expiry.IsZero() { + log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) + defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry) + m.Expiry = &defaultExpiry + h.db.Save(&m) + } + } + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`