updates from code review

This commit is contained in:
Raal Goff 2021-10-08 17:43:52 +08:00
parent 35795c79c3
commit e407d423d4
4 changed files with 131 additions and 43 deletions

56
api.go
View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -83,7 +84,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{ m = Machine{
Expiry: &req.Expiry, Expiry: &time.Time{},
MachineKey: mKey.HexString(), MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(), NodeKey: wgkey.Key(req.NodeKey).HexString(),
@ -107,7 +108,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if m.Registered {
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Client requested logout")
m.Expiry = &req.Expiry
h.db.Save(&m)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
return
}
if m.Registered && m.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
@ -132,14 +159,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Not registered and not NodeKey rotation. Sending a authurl to register") Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register")
if h.cfg.OIDCIssuer != "" { if h.cfg.OIDCIssuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} }
m.Expiry = &req.Expiry // save the requested expiry time for retrieval later
h.db.Save(&m)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -153,8 +185,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
@ -179,14 +211,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We arrive here after a client is restarted without finalizing the authentication flow or // We arrive here after a client is restarted without finalizing the authentication flow or
// when headscale is stopped in the middle of the auth process. // when headscale is stopped in the middle of the auth process.
if m.Registered { if m.Registered && m.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m)
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -210,6 +247,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} }
m.Expiry = &req.Expiry // save the requested expiry time for retrieval later
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey
h.db.Save(&m)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().

17
app.go
View file

@ -3,6 +3,9 @@ package headscale
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http" "net/http"
"os" "os"
"strings" "strings"
@ -49,6 +52,9 @@ type Config struct {
OIDCIssuer string OIDCIssuer string
OIDCClientID string OIDCClientID string
OIDCClientSecret string OIDCClientSecret string
MaxMachineExpiry time.Duration
DefaultMachineExpiry time.Duration
} }
// Headscale represents the base app of the service // Headscale represents the base app of the service
@ -68,6 +74,10 @@ type Headscale struct {
clientsUpdateChannelMutex sync.Mutex clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map lastStateChange sync.Map
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
} }
// NewHeadscale returns the Headscale app // NewHeadscale returns the Headscale app
@ -107,6 +117,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err return nil, err
} }
if cfg.OIDCIssuer != "" {
err = h.initOIDC()
if err != nil {
return nil, err
}
}
return &h, nil return &h, nil
} }

View file

@ -144,6 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err return nil, err
} }
maxMachineExpiry, _ := time.ParseDuration("8h")
if viper.GetDuration("max_machine_expiry") >= time.Second {
maxMachineExpiry = viper.GetDuration("max_machine_expiry")
}
defaultMachineExpiry, _ := time.ParseDuration("8h")
if viper.GetDuration("default_machine_expiry") >= time.Second {
defaultMachineExpiry = viper.GetDuration("default_machine_expiry")
}
cfg := headscale.Config{ cfg := headscale.Config{
ServerURL: viper.GetString("server_url"), ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"), Addr: viper.GetString("listen_addr"),
@ -174,6 +184,9 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
OIDCIssuer: viper.GetString("oidc_issuer"), OIDCIssuer: viper.GetString("oidc_issuer"),
OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientID: viper.GetString("oidc_client_id"),
OIDCClientSecret: viper.GetString("oidc_client_secret"), OIDCClientSecret: viper.GetString("oidc_client_secret"),
MaxMachineExpiry: maxMachineExpiry,
DefaultMachineExpiry: defaultMachineExpiry,
} }
h, err := headscale.NewHeadscale(cfg) h, err := headscale.NewHeadscale(cfg)

88
oidc.go
View file

@ -13,6 +13,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm" "gorm.io/gorm"
"net/http" "net/http"
"strings"
"time" "time"
) )
@ -23,9 +24,33 @@ type IDTokenClaims struct {
Username string `json:"preferred_username,omitempty"` Username string `json:"preferred_username,omitempty"`
} }
var oidcProvider *oidc.Provider func (h *Headscale) initOIDC() error {
var oauth2Config *oauth2.Config var err error
var stateCache *cache.Cache // grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID,
ClientSecret: h.cfg.OIDCClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
}
return nil
}
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
return return
} }
var err error
// grab oidc config if it hasn't been already
if oauth2Config == nil {
oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config")
return
}
oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID,
ClientSecret: h.cfg.OIDCClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
b := make([]byte, 16) b := make([]byte, 16)
_, err = rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
stateStr := hex.EncodeToString(b)[:32] stateStr := hex.EncodeToString(b)[:32]
// init the state cache if it hasn't been already
if stateCache == nil {
stateCache = cache.New(time.Minute*5, time.Minute*10)
}
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
stateCache.Set(stateStr, mKeyStr, time.Minute*5) h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
authUrl := oauth2Config.AuthCodeURL(stateStr) authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl) log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
c.Redirect(http.StatusFound, authUrl) c.Redirect(http.StatusFound, authUrl)
@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return return
} }
oauth2Token, err := oauth2Config.Exchange(context.Background(), code) oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token") c.String(http.StatusBadRequest, "Could not exchange code for token")
return return
@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return return
} }
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
} }
//retrieve machinekey from state cache //retrieve machinekey from state cache
mKeyIf, mKeyFound := stateCache.Get(state) mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound { if !mKeyFound {
c.String(http.StatusBadRequest, "state has expired") c.String(http.StatusBadRequest, "state has expired")
@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
//look for a namespace of the users email for now //look for a namespace of the users email for now
if !m.Registered { if !m.Registered {
log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(claims.Email) ns, err := h.GetNamespace(claims.Email)
if err != nil { if err != nil {
ns, err = h.CreateNamespace(claims.Email) ns, err = h.CreateNamespace(claims.Email)
@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
h.db.Save(&m) h.db.Save(&m)
} }
if m.isExpired() {
maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry)
// use the maximum expiry if it's sooner than the requested expiry
if maxExpiry.Before(*m.Expiry) {
log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
m.Expiry = &maxExpiry
h.db.Save(&m)
} else if m.Expiry.IsZero() {
log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry)
m.Expiry = &defaultExpiry
h.db.Save(&m)
}
}
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>