updates from code review
This commit is contained in:
parent
35795c79c3
commit
e407d423d4
4 changed files with 131 additions and 43 deletions
56
api.go
56
api.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -83,7 +84,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||||
m = Machine{
|
m = Machine{
|
||||||
Expiry: &req.Expiry,
|
Expiry: &time.Time{},
|
||||||
MachineKey: mKey.HexString(),
|
MachineKey: mKey.HexString(),
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
||||||
|
@ -107,7 +108,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
|
|
||||||
// We have the updated key!
|
// We have the updated key!
|
||||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||||
if m.Registered {
|
|
||||||
|
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "Registration").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Client requested logout")
|
||||||
|
|
||||||
|
m.Expiry = &req.Expiry
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
|
resp.AuthURL = ""
|
||||||
|
resp.MachineAuthorized = false
|
||||||
|
resp.User = *m.Namespace.toUser()
|
||||||
|
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "Registration").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot encode message")
|
||||||
|
c.String(http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Registered && m.Expiry.UTC().After(now) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
@ -132,14 +159,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Not registered and not NodeKey rotation. Sending a authurl to register")
|
Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register")
|
||||||
|
|
||||||
if h.cfg.OIDCIssuer != "" {
|
if h.cfg.OIDCIssuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||||
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.Expiry = &req.Expiry // save the requested expiry time for retrieval later
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -153,8 +185,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
|
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
@ -179,14 +211,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
|
|
||||||
// We arrive here after a client is restarted without finalizing the authentication flow or
|
// We arrive here after a client is restarted without finalizing the authentication flow or
|
||||||
// when headscale is stopped in the middle of the auth process.
|
// when headscale is stopped in the middle of the auth process.
|
||||||
if m.Registered {
|
if m.Registered && m.Expiry.UTC().After(now) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
|
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
|
||||||
|
|
||||||
|
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *m.Namespace.toUser()
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -210,6 +247,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.Expiry = &req.Expiry // save the requested expiry time for retrieval later
|
||||||
|
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
17
app.go
17
app.go
|
@ -3,6 +3,9 @@ package headscale
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -49,6 +52,9 @@ type Config struct {
|
||||||
OIDCIssuer string
|
OIDCIssuer string
|
||||||
OIDCClientID string
|
OIDCClientID string
|
||||||
OIDCClientSecret string
|
OIDCClientSecret string
|
||||||
|
|
||||||
|
MaxMachineExpiry time.Duration
|
||||||
|
DefaultMachineExpiry time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Headscale represents the base app of the service
|
// Headscale represents the base app of the service
|
||||||
|
@ -68,6 +74,10 @@ type Headscale struct {
|
||||||
clientsUpdateChannelMutex sync.Mutex
|
clientsUpdateChannelMutex sync.Mutex
|
||||||
|
|
||||||
lastStateChange sync.Map
|
lastStateChange sync.Map
|
||||||
|
|
||||||
|
oidcProvider *oidc.Provider
|
||||||
|
oauth2Config *oauth2.Config
|
||||||
|
oidcStateCache *cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeadscale returns the Headscale app
|
// NewHeadscale returns the Headscale app
|
||||||
|
@ -107,6 +117,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.OIDCIssuer != "" {
|
||||||
|
err = h.initOIDC()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &h, nil
|
return &h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxMachineExpiry, _ := time.ParseDuration("8h")
|
||||||
|
if viper.GetDuration("max_machine_expiry") >= time.Second {
|
||||||
|
maxMachineExpiry = viper.GetDuration("max_machine_expiry")
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultMachineExpiry, _ := time.ParseDuration("8h")
|
||||||
|
if viper.GetDuration("default_machine_expiry") >= time.Second {
|
||||||
|
defaultMachineExpiry = viper.GetDuration("default_machine_expiry")
|
||||||
|
}
|
||||||
|
|
||||||
cfg := headscale.Config{
|
cfg := headscale.Config{
|
||||||
ServerURL: viper.GetString("server_url"),
|
ServerURL: viper.GetString("server_url"),
|
||||||
Addr: viper.GetString("listen_addr"),
|
Addr: viper.GetString("listen_addr"),
|
||||||
|
@ -174,6 +184,9 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
OIDCIssuer: viper.GetString("oidc_issuer"),
|
OIDCIssuer: viper.GetString("oidc_issuer"),
|
||||||
OIDCClientID: viper.GetString("oidc_client_id"),
|
OIDCClientID: viper.GetString("oidc_client_id"),
|
||||||
OIDCClientSecret: viper.GetString("oidc_client_secret"),
|
OIDCClientSecret: viper.GetString("oidc_client_secret"),
|
||||||
|
|
||||||
|
MaxMachineExpiry: maxMachineExpiry,
|
||||||
|
DefaultMachineExpiry: defaultMachineExpiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := headscale.NewHeadscale(cfg)
|
h, err := headscale.NewHeadscale(cfg)
|
||||||
|
|
88
oidc.go
88
oidc.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,9 +24,33 @@ type IDTokenClaims struct {
|
||||||
Username string `json:"preferred_username,omitempty"`
|
Username string `json:"preferred_username,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var oidcProvider *oidc.Provider
|
func (h *Headscale) initOIDC() error {
|
||||||
var oauth2Config *oauth2.Config
|
var err error
|
||||||
var stateCache *cache.Cache
|
// grab oidc config if it hasn't been already
|
||||||
|
if h.oauth2Config == nil {
|
||||||
|
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.oauth2Config = &oauth2.Config{
|
||||||
|
ClientID: h.cfg.OIDCClientID,
|
||||||
|
ClientSecret: h.cfg.OIDCClientSecret,
|
||||||
|
Endpoint: h.oidcProvider.Endpoint(),
|
||||||
|
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// init the state cache if it hasn't been already
|
||||||
|
if h.oidcStateCache == nil {
|
||||||
|
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||||
|
@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// grab oidc config if it hasn't been already
|
|
||||||
if oauth2Config == nil {
|
|
||||||
oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
|
||||||
c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
oauth2Config = &oauth2.Config{
|
|
||||||
ClientID: h.cfg.OIDCClientID,
|
|
||||||
ClientSecret: h.cfg.OIDCClientSecret,
|
|
||||||
Endpoint: oidcProvider.Endpoint(),
|
|
||||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
_, err = rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("could not read 16 bytes from rand")
|
log.Error().Msg("could not read 16 bytes from rand")
|
||||||
|
@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
|
|
||||||
stateStr := hex.EncodeToString(b)[:32]
|
stateStr := hex.EncodeToString(b)[:32]
|
||||||
|
|
||||||
// init the state cache if it hasn't been already
|
|
||||||
if stateCache == nil {
|
|
||||||
stateCache = cache.New(time.Minute*5, time.Minute*10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// place the machine key into the state cache, so it can be retrieved later
|
// place the machine key into the state cache, so it can be retrieved later
|
||||||
stateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
||||||
|
|
||||||
authUrl := oauth2Config.AuthCodeURL(stateStr)
|
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, authUrl)
|
c.Redirect(http.StatusFound, authUrl)
|
||||||
|
@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oauth2Token, err := oauth2Config.Exchange(context.Background(), code)
|
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||||
return
|
return
|
||||||
|
@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
||||||
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//retrieve machinekey from state cache
|
//retrieve machinekey from state cache
|
||||||
mKeyIf, mKeyFound := stateCache.Get(state)
|
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||||
|
|
||||||
if !mKeyFound {
|
if !mKeyFound {
|
||||||
c.String(http.StatusBadRequest, "state has expired")
|
c.String(http.StatusBadRequest, "state has expired")
|
||||||
|
@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
//look for a namespace of the users email for now
|
//look for a namespace of the users email for now
|
||||||
if !m.Registered {
|
if !m.Registered {
|
||||||
|
|
||||||
|
log.Debug().Msg("Registering new machine after successful callback")
|
||||||
|
|
||||||
ns, err := h.GetNamespace(claims.Email)
|
ns, err := h.GetNamespace(claims.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ns, err = h.CreateNamespace(claims.Email)
|
ns, err = h.CreateNamespace(claims.Email)
|
||||||
|
@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
h.db.Save(&m)
|
h.db.Save(&m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.isExpired() {
|
||||||
|
maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry)
|
||||||
|
|
||||||
|
// use the maximum expiry if it's sooner than the requested expiry
|
||||||
|
if maxExpiry.Before(*m.Expiry) {
|
||||||
|
log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
|
||||||
|
m.Expiry = &maxExpiry
|
||||||
|
h.db.Save(&m)
|
||||||
|
} else if m.Expiry.IsZero() {
|
||||||
|
log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry)
|
||||||
|
defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry)
|
||||||
|
m.Expiry = &defaultExpiry
|
||||||
|
h.db.Save(&m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
|
|
Loading…
Reference in a new issue