headscale/hscontrol/oidc.go
fen4o 9d58489903 Add OIDC claim names options
Some identity providers (Auth0, for example) do not allow setting the
groups claim, so administrators must use custom claim names and add
them to the ID token.

This commit adds the following configuration options:

- `oidc.groups_claim` to set the groups claim name
- `oidc.email_claim` to set the email claim name

All claims default to the previous values for backwards compatibility.

The groups claim can now be either a `[]string` or a single `string`, as some
providers return a plain string instead of an array.
2023-11-08 16:00:07 +02:00

753 lines
20 KiB
Go

package hscontrol
import (
"bytes"
"context"
"crypto/rand"
_ "embed"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"html/template"
"net/http"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
)
const (
// randomByteSize is the number of random bytes RegisterOIDC reads to build
// the OIDC state parameter; the bytes are hex-encoded before use.
randomByteSize = 16
)
// Sentinel errors returned by the OIDC callback helpers in this file.
var (
// errEmptyOIDCCallbackParams: the callback request lacked "code" or "state".
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
// errNoOIDCIDToken: the OAuth2 token response carried no string "id_token".
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
// errOIDCAllowedDomains: the email domain is not in the allowed domains list.
errOIDCAllowedDomains = errors.New(
"authenticated principal does not match any allowed domain",
)
// errOIDCAllowedGroups: none of the allowed groups is present in the groups claim.
errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group")
// errOIDCAllowedUsers: the email is not in the allowed users list.
errOIDCAllowedUsers = errors.New(
"authenticated principal does not match any allowed user",
)
// errOIDCInvalidNodeState: returned by validateNodeForOIDCCallback when the
// value cached for the state is not a string.
errOIDCInvalidNodeState = errors.New(
"requested node state key expired before authorisation completed",
)
// errOIDCNodeKeyMissing: returned by validateNodeForOIDCCallback when no
// entry exists in the registration cache for the state.
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
// errOIDCEmailClaimMissing: the configured email claim is absent from the ID token.
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
)
// IDTokenClaims holds the claims that extractIDTokenClaims reads from an OIDC
// ID token.
type IDTokenClaims struct {
// Groups holds the value of the configured groups claim.
// in some cases the groups might be a single value and not a list
Groups stringOrArray
// Email holds the value of the configured email claim.
Email string
}
// stringOrArray is a list of strings that can be decoded from either a JSON
// array of strings or a single JSON string.
type stringOrArray []string

// UnmarshalJSON implements json.Unmarshaler. The input is first decoded as an
// array of strings; only if that fails is it decoded as a single string, which
// then becomes a one-element list. The error of the second attempt is returned
// when neither form matches.
func (s *stringOrArray) UnmarshalJSON(b []byte) error {
	var list []string

	if listErr := json.Unmarshal(b, &list); listErr != nil {
		// Not an array of strings: fall back to the single-value form.
		var single string
		if singleErr := json.Unmarshal(b, &single); singleErr != nil {
			return singleErr
		}

		list = []string{single}
	}

	*s = list

	return nil
}
// rawClaims maps each claim name of an ID token to its still-encoded JSON
// value, so individual claims can be decoded by a configurable name.
type rawClaims map[string]json.RawMessage

// unmarshalClaim decodes the claim called name into v, which must be a pointer
// accepted by json.Unmarshal. It returns an error naming the claim when the
// claim is absent or its value cannot be decoded into v.
func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
	val, ok := c[name]
	if !ok {
		return fmt.Errorf("claim %q not present", name)
	}

	if err := json.Unmarshal(val, v); err != nil {
		// Keep the underlying JSON error reachable through errors.Is/As.
		return fmt.Errorf("decoding claim %q: %w", name, err)
	}

	return nil
}

// hasClaim reports whether a claim called name exists in c.
func (c rawClaims) hasClaim(name string) bool {
	_, ok := c[name]

	return ok
}
// initOIDC prepares the OIDC provider and the OAuth2 client configuration the
// first time it is called; when h.oauth2Config is already set it does nothing.
// It returns the error from oidc.NewProvider if the provider cannot be set up.
func (h *Headscale) initOIDC() error {
	// grab oidc config only if it hasn't been already
	if h.oauth2Config != nil {
		return nil
	}

	var err error

	h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
	if err != nil {
		log.Error().
			Err(err).
			Caller().
			Msgf("Could not retrieve OIDC Config: %s", err.Error())

		return err
	}

	// The callback URL is derived from the server URL without a trailing slash.
	baseURL := strings.TrimSuffix(h.cfg.ServerURL, "/")

	h.oauth2Config = &oauth2.Config{
		ClientID:     h.cfg.OIDC.ClientID,
		ClientSecret: h.cfg.OIDC.ClientSecret,
		Endpoint:     h.oidcProvider.Endpoint(),
		RedirectURL:  fmt.Sprintf("%s/oidc/callback", baseURL),
		Scopes:       h.cfg.OIDC.Scope,
	}

	return nil
}
// determineTokenExpiration picks the expiry to apply to a node: the ID token's
// own expiry when OIDC.UseExpiryFromToken is set, otherwise the current time
// plus the configured OIDC.Expiry duration.
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
	if !h.cfg.OIDC.UseExpiryFromToken {
		return time.Now().Add(h.cfg.OIDC.Expiry)
	}

	return idTokenExpiration
}
// RegisterOIDC starts the OIDC login flow by redirecting the client to the
// provider's authorization URL.
// The node key from the URL is stored in the registration cache under a random
// state value, so the callback can retrieve it using the oidc state param.
// Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(
	writer http.ResponseWriter,
	req *http.Request,
) {
	// respondPlain sends a plain-text reply with the given status code.
	respondPlain := func(status int, body string) {
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(status)

		if _, writeErr := writer.Write([]byte(body)); writeErr != nil {
			util.LogErr(writeErr, "Failed to write response")
		}
	}

	nodeKeyStr, ok := mux.Vars(req)["nkey"]

	log.Debug().
		Caller().
		Str("node_key", nodeKeyStr).
		Bool("ok", ok).
		Msg("Received oidc register call")

	if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
		log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
		respondPlain(http.StatusUnauthorized, "Unauthorized")

		return
	}

	// We need to make sure we dont open for XSS style injections, if the parameter that
	// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
	// the template and log an error.
	var nodeKey key.NodePublic

	prefixedKey := util.NodePublicKeyEnsurePrefix(nodeKeyStr)
	parseErr := nodeKey.UnmarshalText([]byte(prefixedKey))

	if !ok || nodeKeyStr == "" || parseErr != nil {
		log.Warn().
			Err(parseErr).
			Msg("Failed to parse incoming nodekey in OIDC registration")
		respondPlain(http.StatusBadRequest, "Wrong params")

		return
	}

	stateBytes := make([]byte, randomByteSize)
	if _, randErr := rand.Read(stateBytes); randErr != nil {
		util.LogErr(randErr, "could not read 16 bytes from rand")
		http.Error(writer, "Internal server error", http.StatusInternalServerError)

		return
	}

	stateStr := hex.EncodeToString(stateBytes)[:32]

	// place the node key into the state cache, so it can be retrieved later
	h.registrationCache.Set(
		stateStr,
		util.NodePublicKeyStripPrefix(nodeKey),
		registerCacheExpiration,
	)

	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
	authOpts := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
	for param, value := range h.cfg.OIDC.ExtraParams {
		authOpts = append(authOpts, oauth2.SetAuthURLParam(param, value))
	}

	authURL := h.oauth2Config.AuthCodeURL(stateStr, authOpts...)
	log.Debug().Msgf("Redirecting to %s for authentication", authURL)

	http.Redirect(writer, req, authURL, http.StatusFound)
}
// oidcCallbackTemplateConfig is the data passed to oidcCallbackTemplate when
// rendering the page shown after a completed OIDC callback.
type oidcCallbackTemplateConfig struct {
// User is filled with the email claim of the authenticated principal.
User string
// Verb describes what happened; this file uses "Authenticated" and
// "Reauthenticated".
Verb string
}
// oidcCallbackTemplateContent holds the source of the callback page template,
// embedded at build time from assets/oidc_callback_template.html.
//go:embed assets/oidc_callback_template.html
var oidcCallbackTemplateContent string
// oidcCallbackTemplate is the parsed callback page template. template.Must
// panics during package initialisation if the embedded source fails to parse.
var oidcCallbackTemplate = template.Must(
template.New("oidccallback").Parse(oidcCallbackTemplateContent),
)
// OIDCCallback handles the callback from the OIDC endpoint.
// It exchanges the code for an ID token, verifies it, checks the claims against
// the allowed domains, groups and users, and then retrieves the nkey from the
// state cache and adds the node to the user derived from the email claim.
// Every helper writes its own error response, so this handler simply stops at
// the first failure.
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into node HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(
	writer http.ResponseWriter,
	req *http.Request,
) {
	authCode, state, err := validateOIDCCallbackParams(writer, req)
	if err != nil {
		return
	}

	rawToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, authCode, state)
	if err != nil {
		return
	}

	verifiedToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawToken)
	if err != nil {
		return
	}

	nodeExpiry := h.determineTokenExpiration(verifiedToken.Expiry)

	// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
	// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
	// if err != nil {
	// 		c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
	// 		return
	// }

	oidcCfg := h.cfg.OIDC

	claims, err := extractIDTokenClaims(writer, oidcCfg, verifiedToken)
	if err != nil {
		return
	}

	// Access checks: domain, then groups, then explicit users.
	if validateOIDCAllowedDomains(writer, oidcCfg.AllowedDomains, claims) != nil {
		return
	}

	if validateOIDCAllowedGroups(writer, oidcCfg.AllowedGroups, claims) != nil {
		return
	}

	if validateOIDCAllowedUsers(writer, oidcCfg.AllowedUsers, claims) != nil {
		return
	}

	nodeKey, alreadyRegistered, err := h.validateNodeForOIDCCallback(
		writer,
		state,
		claims,
		nodeExpiry,
	)
	// An existing node has already been refreshed and answered by the helper.
	if err != nil || alreadyRegistered {
		return
	}

	userName, err := getUserName(writer, claims, oidcCfg.StripEmaildomain)
	if err != nil {
		return
	}

	// register the node if it's new
	log.Debug().Msg("Registering new node after successful callback")

	user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
	if err != nil {
		return
	}

	if h.registerNodeForOIDCCallback(writer, user, nodeKey, nodeExpiry) != nil {
		return
	}

	page, err := renderOIDCCallbackTemplate(writer, claims)
	if err != nil {
		return
	}

	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
	writer.WriteHeader(http.StatusOK)

	if _, writeErr := writer.Write(page.Bytes()); writeErr != nil {
		util.LogErr(writeErr, "Failed to write response")
	}
}
// validateOIDCCallbackParams reads the "code" and "state" query parameters of
// the callback request and returns them in that order. If either is empty it
// writes a 400 "Wrong params" response and returns errEmptyOIDCCallbackParams.
func validateOIDCCallbackParams(
	writer http.ResponseWriter,
	req *http.Request,
) (string, string, error) {
	query := req.URL.Query()
	code, state := query.Get("code"), query.Get("state")

	if code != "" && state != "" {
		return code, state, nil
	}

	writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
	writer.WriteHeader(http.StatusBadRequest)

	if _, writeErr := writer.Write([]byte("Wrong params")); writeErr != nil {
		util.LogErr(writeErr, "Failed to write response")
	}

	return "", "", errEmptyOIDCCallbackParams
}
// getIDTokenForOIDCCallback exchanges the authorization code for an OAuth2
// token and returns the raw ID token found in its "id_token" extra field.
//
// On failure it writes a 400 plain-text response and returns a non-nil error:
// the exchange error itself, or errNoOIDCIDToken when the token response does
// not carry a string "id_token". The state parameter is only used for logging.
func (h *Headscale) getIDTokenForOIDCCallback(
	ctx context.Context,
	writer http.ResponseWriter,
	code, state string,
) (string, error) {
	oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
	if err != nil {
		util.LogErr(err, "Could not exchange code for token")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		// Log the write failure itself, not the exchange error reported above.
		if _, werr := writer.Write([]byte("Could not exchange code for token")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return "", err
	}

	log.Trace().
		Caller().
		Str("code", code).
		Str("state", state).
		Msg("Got oidc callback")

	rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
	if !rawIDTokenOK {
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		if _, werr := writer.Write([]byte("Could not extract ID Token")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return "", errNoOIDCIDToken
	}

	return rawIDToken, nil
}
// verifyIDTokenForOIDCCallback verifies rawIDToken with a verifier obtained
// from the OIDC provider for the configured client ID and returns the parsed
// token.
//
// When verification fails it writes a 400 plain-text response and returns the
// verification error.
func (h *Headscale) verifyIDTokenForOIDCCallback(
	ctx context.Context,
	writer http.ResponseWriter,
	rawIDToken string,
) (*oidc.IDToken, error) {
	verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})

	idToken, err := verifier.Verify(ctx, rawIDToken)
	if err != nil {
		util.LogErr(err, "failed to verify id token")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		// Log the write failure itself, not the verification error reported above.
		if _, werr := writer.Write([]byte("Failed to verify id token")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, err
	}

	return idToken, nil
}
// extractIDTokenClaims decodes the email and groups claims from idToken, using
// the claim names configured in cfg.EmailClaim and cfg.GroupsClaim.
//
// The email claim is mandatory: its absence yields errOIDCEmailClaimMissing.
// The groups claim is optional and only decoded when present. Every failure is
// reported to the client through handleClaimError before the error is returned.
func extractIDTokenClaims(
	writer http.ResponseWriter,
	cfg types.OIDCConfig,
	idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
	// fail reports the error to the client and produces the error return values.
	fail := func(claimErr error) (*IDTokenClaims, error) {
		handleClaimError(writer, claimErr)

		return nil, claimErr
	}

	var raw rawClaims
	if err := idToken.Claims(&raw); err != nil {
		return fail(err)
	}

	if !raw.hasClaim(cfg.EmailClaim) {
		return fail(errOIDCEmailClaimMissing)
	}

	var result IDTokenClaims

	if err := raw.unmarshalClaim(cfg.EmailClaim, &result.Email); err != nil {
		return fail(err)
	}

	if !raw.hasClaim(cfg.GroupsClaim) {
		return &result, nil
	}

	if err := raw.unmarshalClaim(cfg.GroupsClaim, &result.Groups); err != nil {
		return fail(err)
	}

	return &result, nil
}
// handleClaimError logs err as a claim decoding failure and answers the client
// with a 400 plain-text response.
func handleClaimError(writer http.ResponseWriter, err error) {
	util.LogErr(err, "Failed to decode id token rawClaims")
	writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
	writer.WriteHeader(http.StatusBadRequest)

	// Log the write failure itself, not the claim error already reported above.
	if _, werr := writer.Write([]byte("Failed to decode id token rawClaims")); werr != nil {
		util.LogErr(werr, "Failed to write response")
	}
}
// validateOIDCAllowedDomains enforces the optional allowed domains list: when
// allowedDomains is non-empty, the part of the email claim after its last "@"
// must be one of the listed domains. Otherwise a 400 response is written and
// errOIDCAllowedDomains is returned. An empty list allows everyone.
func validateOIDCAllowedDomains(
	writer http.ResponseWriter,
	allowedDomains []string,
	claims *IDTokenClaims,
) error {
	if len(allowedDomains) == 0 {
		return nil
	}

	email := claims.Email
	at := strings.LastIndex(email, "@")

	if at >= 0 && util.IsStringInSlice(allowedDomains, email[at+1:]) {
		return nil
	}

	log.Trace().Msg("authenticated principal does not match any allowed domain")
	writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
	writer.WriteHeader(http.StatusBadRequest)

	if _, writeErr := writer.Write([]byte("unauthorized principal (domain mismatch)")); writeErr != nil {
		util.LogErr(writeErr, "Failed to write response")
	}

	return errOIDCAllowedDomains
}
// validateOIDCAllowedGroups enforces the optional allowed groups list: when
// allowedGroups is non-empty, at least one of its entries must appear in
// claims.Groups. Otherwise a 400 response is written and errOIDCAllowedGroups
// is returned. An empty list allows everyone.
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
	writer http.ResponseWriter,
	allowedGroups []string,
	claims *IDTokenClaims,
) error {
	if len(allowedGroups) == 0 {
		return nil
	}

	for _, allowed := range allowedGroups {
		if util.IsStringInSlice(claims.Groups, allowed) {
			return nil
		}
	}

	log.Trace().Msg("authenticated principal not in any allowed groups")
	writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
	writer.WriteHeader(http.StatusBadRequest)

	if _, writeErr := writer.Write([]byte("unauthorized principal (allowed groups)")); writeErr != nil {
		util.LogErr(writeErr, "Failed to write response")
	}

	return errOIDCAllowedGroups
}
// validateOIDCAllowedUsers enforces the optional allowed users list: when
// allowedUsers is non-empty, the email claim must be one of its entries.
// Otherwise a 400 response is written and errOIDCAllowedUsers is returned.
// An empty list allows everyone.
func validateOIDCAllowedUsers(
	writer http.ResponseWriter,
	allowedUsers []string,
	claims *IDTokenClaims,
) error {
	if len(allowedUsers) == 0 || util.IsStringInSlice(allowedUsers, claims.Email) {
		return nil
	}

	log.Trace().Msg("authenticated principal does not match any allowed user")
	writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
	writer.WriteHeader(http.StatusBadRequest)

	if _, writeErr := writer.Write([]byte("unauthorized principal (user mismatch)")); writeErr != nil {
		util.LogErr(writeErr, "Failed to write response")
	}

	return errOIDCAllowedUsers
}
// validateNodeForOIDCCallback resolves the node key cached under state and
// checks whether a node with that key already exists.
//
// Return values:
//   - (&nodeKey, false, nil): no node exists for the key; the caller should
//     register a new node. No response has been written.
//   - (nil, true, nil): the node exists; its expiry was set to expiry and the
//     "Reauthenticated" page has been written to the client.
//   - (nil, _, err): a failure occurred and an error response has been written.
//     The bool is true when the failure happened while refreshing an existing
//     node.
func (h *Headscale) validateNodeForOIDCCallback(
	writer http.ResponseWriter,
	state string,
	claims *IDTokenClaims,
	expiry time.Time,
) (*key.NodePublic, bool, error) {
	// retrieve nodekey from state cache
	nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
	if !nodeKeyFound {
		log.Trace().
			Msg("requested node state key expired before authorisation completed")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		if _, werr := writer.Write([]byte("state has expired")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, false, errOIDCNodeKeyMissing
	}

	var nodeKey key.NodePublic

	nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
	if !nodeKeyOK {
		log.Trace().
			Msg("requested node state key is not a string")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		if _, werr := writer.Write([]byte("state is invalid")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, false, errOIDCInvalidNodeState
	}

	err := nodeKey.UnmarshalText(
		[]byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
	)
	if err != nil {
		log.Error().
			Err(err).
			Str("nodeKey", nodeKeyFromCache).
			Bool("nodeKeyOK", nodeKeyOK).
			Msg("could not parse node public key")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusBadRequest)

		// Log the write failure itself, not the parse error reported above.
		if _, werr := writer.Write([]byte("could not parse node public key")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, false, err
	}

	// retrieve node information if it exist
	// The error is not important, because if it does not
	// exist, then this is a new node and we will move
	// on to registration.
	node, _ := h.db.GetNodeByNodeKey(nodeKey)
	if node == nil {
		return &nodeKey, false, nil
	}

	log.Trace().
		Caller().
		Str("node", node.Hostname).
		Msg("node already registered, reauthenticating")

	if err := h.db.NodeSetExpiry(node, expiry); err != nil {
		util.LogErr(err, "Failed to refresh node")
		http.Error(
			writer,
			"Failed to refresh node",
			http.StatusInternalServerError,
		)

		return nil, true, err
	}

	log.Debug().
		Str("node", node.Hostname).
		Str("expiresAt", fmt.Sprintf("%v", expiry)).
		Msg("successfully refreshed node")

	var content bytes.Buffer
	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
		User: claims.Email,
		Verb: "Reauthenticated",
	}); err != nil {
		log.Error().
			Str("func", "OIDCCallback").
			Str("type", "reauthenticate").
			Err(err).
			Msg("Could not render OIDC callback template")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusInternalServerError)

		// Log the write failure itself, not the template error reported above.
		if _, werr := writer.Write([]byte("Could not render OIDC callback template")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, true, err
	}

	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
	writer.WriteHeader(http.StatusOK)

	if _, werr := writer.Write(content.Bytes()); werr != nil {
		util.LogErr(werr, "Failed to write response")
	}

	return nil, true, nil
}
// getUserName derives the headscale user name from the email claim by passing
// it, together with stripEmaildomain, to util.NormalizeToFQDNRules.
//
// When normalisation fails it writes a 500 plain-text response and returns the
// normalisation error.
func getUserName(
	writer http.ResponseWriter,
	claims *IDTokenClaims,
	stripEmaildomain bool,
) (string, error) {
	userName, err := util.NormalizeToFQDNRules(
		claims.Email,
		stripEmaildomain,
	)
	if err != nil {
		util.LogErr(err, "couldn't normalize email")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusInternalServerError)

		// Log the write failure itself, not the normalisation error reported above.
		if _, werr := writer.Write([]byte("couldn't normalize email")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return "", err
	}

	return userName, nil
}
// findOrCreateNewUserForOIDCCallback returns the user called userName, creating
// it when the lookup reports db.ErrUserNotFound.
//
// If the lookup fails for any other reason, or the creation fails, it writes a
// 500 plain-text response and returns that error.
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
	writer http.ResponseWriter,
	userName string,
) (*types.User, error) {
	user, err := h.db.GetUser(userName)
	if errors.Is(err, db.ErrUserNotFound) {
		user, err = h.db.CreateUser(userName)
		if err != nil {
			log.Error().
				Err(err).
				Caller().
				Msgf("could not create new user '%s'", userName)
			writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
			writer.WriteHeader(http.StatusInternalServerError)

			// Log the write failure itself, not the creation error reported above.
			if _, werr := writer.Write([]byte("could not create user")); werr != nil {
				util.LogErr(werr, "Failed to write response")
			}

			return nil, err
		}
	} else if err != nil {
		log.Error().
			Caller().
			Err(err).
			Str("user", userName).
			Msg("could not find or create user")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusInternalServerError)

		// Log the write failure itself, not the lookup error reported above.
		if _, werr := writer.Write([]byte("could not find or create user")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, err
	}

	return user, nil
}
// registerNodeForOIDCCallback registers the node identified by nodeKey for
// user through h.db.RegisterNodeFromAuthCallback, using the OIDC register
// method and the given expiry.
//
// When registration fails it writes a 500 plain-text response and returns the
// registration error.
func (h *Headscale) registerNodeForOIDCCallback(
	writer http.ResponseWriter,
	user *types.User,
	nodeKey *key.NodePublic,
	expiry time.Time,
) error {
	if _, err := h.db.RegisterNodeFromAuthCallback(
		// TODO(kradalby): find a better way to use the cache across modules
		h.registrationCache,
		nodeKey.String(),
		user.Name,
		&expiry,
		util.RegisterMethodOIDC,
	); err != nil {
		util.LogErr(err, "could not register node")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusInternalServerError)

		// Log the write failure itself, not the registration error reported above.
		if _, werr := writer.Write([]byte("could not register node")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return err
	}

	return nil
}
// renderOIDCCallbackTemplate renders the "Authenticated" callback page for the
// email in claims and returns the rendered content without writing it.
//
// When the template fails to execute it writes a 500 plain-text response and
// returns the template error.
func renderOIDCCallbackTemplate(
	writer http.ResponseWriter,
	claims *IDTokenClaims,
) (*bytes.Buffer, error) {
	var content bytes.Buffer
	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
		User: claims.Email,
		Verb: "Authenticated",
	}); err != nil {
		log.Error().
			Str("func", "OIDCCallback").
			Str("type", "authenticate").
			Err(err).
			Msg("Could not render OIDC callback template")
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusInternalServerError)

		// Log the write failure itself, not the template error reported above.
		if _, werr := writer.Write([]byte("Could not render OIDC callback template")); werr != nil {
			util.LogErr(werr, "Failed to write response")
		}

		return nil, err
	}

	return &content, nil
}