diff --git a/api.go b/api.go index 2c5a132..fb54f3c 100644 --- a/api.go +++ b/api.go @@ -134,7 +134,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("Not registered and not NodeKey rotation. Sending a authurl to register") - if h.cfg.OIDCEndpoint != "" { + if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", @@ -204,7 +204,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - if h.cfg.OIDCEndpoint != "" { + if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", diff --git a/app.go b/app.go index 6f24c82..3c4b307 100644 --- a/app.go +++ b/app.go @@ -46,7 +46,7 @@ type Config struct { DNSConfig *tailcfg.DNSConfig - OIDCEndpoint string + OIDCIssuer string OIDCClientID string OIDCClientSecret string } @@ -172,11 +172,11 @@ func (h *Headscale) Serve() error { r.GET("/register", h.RegisterWebAPI) r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) - r.GET("/oidc/register/:mKey", h.RegisterOIDC) + r.GET("/oidc/register/:mkey", h.RegisterOIDC) r.GET("/oidc/callback", h.OIDCCallback) r.GET("/apple", h.AppleMobileConfig) r.GET("/apple/:platform", h.ApplePlatformConfig) - + var err error timeout := 30 * time.Second diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index b7faad5..6ccdcde 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -171,7 +171,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { DNSConfig: GetDNSConfig(), - OIDCEndpoint: viper.GetString("oidc_endpoint"), + OIDCIssuer: viper.GetString("oidc_issuer"), OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), } diff --git a/go.mod b/go.mod index c1d4561..a770338 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,12 @@ require ( github.com/Microsoft/go-winio v0.5.0 // indirect github.com/cenkalti/backoff/v4 v4.1.1 // indirect github.com/containerd/continuity v0.1.0 // indirect + github.com/coreos/go-oidc/v3 v3.1.0 github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 github.com/gin-gonic/gin v1.7.4 - github.com/gofrs/uuid v4.0.0+incompatible // indirect + github.com/gofrs/uuid v4.0.0+incompatible github.com/google/go-github v17.0.0+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b @@ -28,13 +29,13 @@ require ( github.com/spf13/viper v1.8.1 github.com/stretchr/testify v1.7.0 github.com/tailscale/hujson v0.0.0-20210818175511-7360507a6e88 - github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e // indirect + github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect + golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c - gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.4.0 gorm.io/datatypes v1.0.2 gorm.io/driver/postgres v1.1.1 diff --git a/go.sum b/go.sum index 9a76d17..fc498e7 100644 --- a/go.sum +++ b/go.sum @@ -143,6 +143,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= +github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -1067,6 +1069,7 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= @@ -1101,6 +1104,7 @@ golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 h1:0Ja1LBD+yisY6RWM/BH7TJVXWsSjs2VwBSmvSX4HdBc= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1355,6 +1359,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1452,8 +1457,9 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= +gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= diff --git a/machine.go b/machine.go index 1d4939c..b5c821f 100644 --- a/machine.go +++ b/machine.go @@ -50,6 +50,11 @@ func (m Machine) isAlreadyRegistered() bool { return m.Registered } +// isExpired returns whether the machine registration has expired +func (m Machine) isExpired() bool { + return time.Now().UTC().After(*m.Expiry) +} + // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { diff --git a/oidc.go b/oidc.go index dabd8b0..aa80911 100644 --- a/oidc.go +++ b/oidc.go @@ -1,186 +1,37 @@ package headscale import ( + "context" "crypto/rand" "encoding/hex" - "encoding/json" "errors" "fmt" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" - "github.com/s12v/go-jwks" - "gopkg.in/square/go-jose.v2/jwt" + "golang.org/x/oauth2" "gorm.io/gorm" - "io" "net/http" - "net/url" - "strings" "time" ) -type OpenIDConfiguration struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - JWKSURI string `json:"jwks_uri"` -} - -type OpenIDTokens struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - IdToken string `json:"id_token"` - NotBeforePolicy int `json:"not-before-policy,omitempty"` - RefreshExpiresIn int `json:"refresh_expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - SessionState string `json:"session_state,omitempty"` - TokenType string `json:"token_type,omitempty"` -} - -type AccessToken struct { - jwt.Claims +type IDTokenClaims struct { Name string `json:"name,omitempty"` Groups []string `json:"groups,omitempty"` Email string `json:"email"` Username string `json:"preferred_username,omitempty"` } -var oidcConfig *OpenIDConfiguration +var oidcProvider *oidc.Provider +var oauth2Config *oauth2.Config var stateCache *cache.Cache -var jwksSource *jwks.WebSource -var jwksClient jwks.JWKSClient - -func verifyToken(token string) (*AccessToken, error) { - - if jwksClient == nil { - jwksSource = jwks.NewWebSource(oidcConfig.JWKSURI) - jwksClient = jwks.NewDefaultClient( - jwksSource, - time.Hour, // Refresh keys every 1 hour - 12*time.Hour, // Expire keys after 12 hours - ) - } - - //decode jwt - tok, err := jwt.ParseSigned(token) - if err != nil { - return nil, err - } - - if tok.Headers[0].KeyID != "" { - log.Debug().Msgf("Checking KID %s\n", tok.Headers[0].KeyID) - - jwk, err := jwksClient.GetSignatureKey(tok.Headers[0].KeyID) - if err != nil { - return nil, err - } - - claims := AccessToken{} - - err = tok.Claims(jwk.Certificates[0].PublicKey, &claims) - if err != nil { - return nil, err - } else { - - err = claims.Validate(jwt.Expected{ - Time: time.Now(), - }) - if err != nil { - return nil, err - } - - return &claims, nil - } - - } else { - return nil, errors.New("JWT does not contain a key id") - } -} - -func getOIDCConfig(oidcConfigURL string) (*OpenIDConfiguration, error) { - client := &http.Client{} - req, err := http.NewRequest("GET", oidcConfigURL, nil) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - log.Debug().Msgf("Requesting OIDC Config from %s", oidcConfigURL) - - oidcConfigResp, err := client.Do(req) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - defer oidcConfigResp.Body.Close() - - var oidcConfig OpenIDConfiguration - - err = json.NewDecoder(oidcConfigResp.Body).Decode(&oidcConfig) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - return &oidcConfig, nil -} - -func (h *Headscale) exchangeCodeForTokens(code string, redirectURI string) (*OpenIDTokens, error) { - var err error - - if oidcConfig == nil { - oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) - if err != nil { - return nil, err - } - } - - params := url.Values{} - params.Add("grant_type", "authorization_code") - params.Add("code", code) - params.Add("client_id", h.cfg.OIDCClientID) - params.Add("client_secret", h.cfg.OIDCClientSecret) - params.Add("redirect_uri", redirectURI) - - client := &http.Client{} - req, err := http.NewRequest("POST", oidcConfig.TokenEndpoint, strings.NewReader(params.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - tokenResp, err := client.Do(req) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - defer tokenResp.Body.Close() - - if tokenResp.StatusCode != 200 { - b, _ := io.ReadAll(tokenResp.Body) - log.Error().Msgf("%s", b) - } - - var tokens OpenIDTokens - - err = json.NewDecoder(tokenResp.Body).Decode(&tokens) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - log.Info().Msg("Successfully exchanged code for tokens") - - return &tokens, nil -} // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param // Listens in /oidc/register/:mKey func (h *Headscale) RegisterOIDC(c *gin.Context) { - mKeyStr := c.Param("mKey") + mKeyStr := c.Param("mkey") if mKeyStr == "" { c.String(http.StatusBadRequest, "Wrong params") return @@ -189,13 +40,23 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { var err error // grab oidc config if it hasn't been already - if oidcConfig == nil { - oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) + if oauth2Config == nil { + oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) if err != nil { + log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config") return } + + oauth2Config = &oauth2.Config{ + ClientID: h.cfg.OIDCClientID, + ClientSecret: h.cfg.OIDCClientSecret, + Endpoint: oidcProvider.Endpoint(), + RedirectURL: fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + } b := make([]byte, 16) @@ -217,21 +78,16 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // place the machine key into the state cache, so it can be retrieved later stateCache.Set(stateStr, mKeyStr, time.Minute*5) - params := url.Values{} - params.Add("response_type", "code") - params.Add("client_id", h.cfg.OIDCClientID) - params.Add("redirect_uri", fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL)) - params.Add("scope", "openid") - params.Add("state", stateStr) - - authUrl := fmt.Sprintf("%s?%s", oidcConfig.AuthorizationEndpoint, params.Encode()) - log.Debug().Msg(authUrl) + authUrl := oauth2Config.AuthCodeURL(stateStr) + log.Debug().Msgf("Redirecting to %s for authentication", authUrl) c.Redirect(http.StatusFound, authUrl) } // OIDCCallback handles the callback from the OIDC endpoint -// Retrieves the mkey from the state cache, if the machine is not registered, presents a confirmation +// Retrieves the mkey from the state cache and adds the machine to the users email namespace +// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities +// TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback func (h *Headscale) OIDCCallback(c *gin.Context) { @@ -243,20 +99,36 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - redirectURI := fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL) - - tokens, err := h.exchangeCodeForTokens(code, redirectURI) - + oauth2Token, err := oauth2Config.Exchange(context.Background(), code) if err != nil { c.String(http.StatusBadRequest, "Could not exchange code for token") return } - //verify tokens - claims, err := verifyToken(tokens.AccessToken) + rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) + if !rawIDTokenOK { + c.String(http.StatusBadRequest, "Could not extract ID Token") + return + } + verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + + idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { - c.String(http.StatusBadRequest, "invalid tokens") + c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + return + } + + //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) + //if err != nil { + // c.String(http.StatusBadRequest, "Failed to retrieve userinfo: "+err.Error()) + // return + //} + + // Extract custom claims + var claims IDTokenClaims + if err = idToken.Claims(&claims); err != nil { + c.String(http.StatusBadRequest, "Failed to decode id token claims: "+err.Error()) return }