From a1e7e771cecfbcd988c340e1e8fd1d6cd5a5d467 Mon Sep 17 00:00:00 2001 From: Grigoriy Mikhalkin Date: Sun, 7 Aug 2022 13:57:07 +0200 Subject: [PATCH] refactor OIDC callback aux functions --- oidc.go | 185 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 96 insertions(+), 89 deletions(-) diff --git a/oidc.go b/oidc.go index 5509bd4..a385a92 100644 --- a/oidc.go +++ b/oidc.go @@ -21,6 +21,13 @@ import ( const ( randomByteSize = 16 + + errEmptyOIDCCallbackParams = Error("empty OIDC callback params") + errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback") + errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") + errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") + errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") + errOIDCMachineKeyMissing = Error("could not get machine key from cache") ) type IDTokenClaims struct { @@ -136,18 +143,18 @@ func (h *Headscale) OIDCCallback( writer http.ResponseWriter, req *http.Request, ) { - code, state, ok := validateOIDCCallbackParams(writer, req) - if !ok { + code, state, err := validateOIDCCallbackParams(writer, req) + if err != nil { return } - rawIDToken, ok := h.getIDTokenForOIDCCallback(writer, code, state) - if !ok { + rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state) + if err != nil { return } - idToken, ok := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) - if !ok { + idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) + if err != nil { return } @@ -158,43 +165,43 @@ func (h *Headscale) OIDCCallback( // return // } - claims, ok := extractIDTokenClaims(writer, idToken) - if !ok { + claims, err := extractIDTokenClaims(writer, idToken) + if err != nil { return } - if ok := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); !ok { + if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil { return } - if ok := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); !ok { + if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil { return } - machineKey, ok := h.validateMachineForOIDCCallback(writer, state, claims) - if !ok { + machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) + if err != nil || machineExists { return } - namespaceName, ok := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain) - if !ok { + namespaceName, err := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain) + if err != nil { return } // register the machine if it's new log.Debug().Msg("Registering new machine after successful callback") - namespace, ok := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName) - if !ok { + namespace, err := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName) + if err != nil { return } - if ok := h.registerMachineForOIDCCallback(writer, namespace, machineKey); !ok { + if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { return } - content, ok := renderOIDCCallbackTemplate(writer, claims) - if !ok { + content, err := renderOIDCCallbackTemplate(writer, claims) + if err != nil { return } @@ -211,7 +218,7 @@ func (h *Headscale) OIDCCallback( func validateOIDCCallbackParams( writer http.ResponseWriter, req *http.Request, -) (string, string, bool) { +) (string, string, error) { code := req.URL.Query().Get("code") state := req.URL.Query().Get("state") @@ -226,16 +233,16 @@ func validateOIDCCallbackParams( Msg("Failed to write response") } - return "", "", false + return "", "", errEmptyOIDCCallbackParams } - return code, state, true + return code, state, nil } func (h *Headscale) getIDTokenForOIDCCallback( writer http.ResponseWriter, code, state string, -) (string, bool) { +) (string, error) { oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) if err != nil { log.Error(). @@ -244,15 +251,15 @@ func (h *Headscale) getIDTokenForOIDCCallback( Msg("Could not exchange code for token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Could not exchange code for token")) - if err != nil { + _, werr := writer.Write([]byte("Could not exchange code for token")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return "", false + return "", err } log.Trace(). @@ -273,16 +280,16 @@ func (h *Headscale) getIDTokenForOIDCCallback( Msg("Failed to write response") } - return "", false + return "", errNoOIDCIDToken } - return rawIDToken, true + return rawIDToken, nil } func (h *Headscale) verifyIDTokenForOIDCCallback( writer http.ResponseWriter, rawIDToken string, -) (*oidc.IDToken, bool) { +) (*oidc.IDToken, error) { verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -292,24 +299,24 @@ func (h *Headscale) verifyIDTokenForOIDCCallback( Msg("failed to verify id token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Failed to verify id token")) - if err != nil { + _, werr := writer.Write([]byte("Failed to verify id token")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, err } - return idToken, true + return idToken, nil } func extractIDTokenClaims( writer http.ResponseWriter, idToken *oidc.IDToken, -) (*IDTokenClaims, bool) { +) (*IDTokenClaims, error) { var claims IDTokenClaims if err := idToken.Claims(claims); err != nil { log.Error(). @@ -318,18 +325,18 @@ func extractIDTokenClaims( Msg("Failed to decode id token claims") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Failed to decode id token claims")) - if err != nil { + _, werr := writer.Write([]byte("Failed to decode id token claims")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, err } - return &claims, true + return &claims, nil } // validateOIDCAllowedDomains checks that if AllowedDomains is provided, @@ -338,7 +345,7 @@ func validateOIDCAllowedDomains( writer http.ResponseWriter, allowedDomains []string, claims *IDTokenClaims, -) bool { +) error { if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !IsStringInSlice(allowedDomains, claims.Email[at+1:]) { @@ -353,11 +360,11 @@ func validateOIDCAllowedDomains( Msg("Failed to write response") } - return false + return errOIDCAllowedDomains } } - return true + return nil } // validateOIDCAllowedUsers checks that if AllowedUsers is provided, @@ -366,7 +373,7 @@ func validateOIDCAllowedUsers( writer http.ResponseWriter, allowedUsers []string, claims *IDTokenClaims, -) bool { +) error { if len(allowedUsers) > 0 && !IsStringInSlice(allowedUsers, claims.Email) { log.Error().Msg("authenticated principal does not match any allowed user") @@ -380,10 +387,10 @@ func validateOIDCAllowedUsers( Msg("Failed to write response") } - return false + return errOIDCAllowedUsers } - return true + return nil } // validateMachine retrieves machine information if it exist @@ -394,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback( writer http.ResponseWriter, state string, claims *IDTokenClaims, -) (*key.MachinePublic, bool) { +) (*key.MachinePublic, bool, error) { // retrieve machinekey from state cache machineKeyIf, machineKeyFound := h.registrationCache.Get(state) if !machineKeyFound { @@ -410,7 +417,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("Failed to write response") } - return nil, false + return nil, false, errOIDCInvalidMachineState } var machineKey key.MachinePublic @@ -423,15 +430,15 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("could not parse machine public key") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("could not parse public key")) - if err != nil { + _, werr := writer.Write([]byte("could not parse public key")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, false, err } if !machineKeyOK { @@ -446,7 +453,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("Failed to write response") } - return nil, false + return nil, false, errOIDCMachineKeyMissing } // retrieve machine information if it exist @@ -469,7 +476,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("Failed to refresh machine") http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError) - return nil, false + return nil, true, err } var content bytes.Buffer @@ -485,15 +492,15 @@ func (h *Headscale) validateMachineForOIDCCallback( writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render OIDC callback template")) - if err != nil { + _, werr := writer.Write([]byte("Could not render OIDC callback template")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, true, err } writer.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -506,17 +513,17 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("Failed to write response") } - return nil, false + return nil, true, nil } - return &machineKey, true + return &machineKey, false, nil } func getNamespaceName( writer http.ResponseWriter, claims *IDTokenClaims, stripEmaildomain bool, -) (string, bool) { +) (string, error) { namespaceName, err := NormalizeToFQDNRules( claims.Email, stripEmaildomain, @@ -525,24 +532,24 @@ func getNamespaceName( log.Error().Err(err).Caller().Msgf("couldn't normalize email") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("couldn't normalize email")) - if err != nil { + _, werr := writer.Write([]byte("couldn't normalize email")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return "", false + return "", err } - return namespaceName, true + return namespaceName, nil } func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( writer http.ResponseWriter, namespaceName string, -) (*Namespace, bool) { +) (*Namespace, error) { namespace, err := h.GetNamespace(namespaceName) if errors.Is(err, errNamespaceNotFound) { namespace, err = h.CreateNamespace(namespaceName) @@ -554,15 +561,15 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( Msgf("could not create new namespace '%s'", namespaceName) writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not create namespace")) - if err != nil { + _, werr := writer.Write([]byte("could not create namespace")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, err } } else if err != nil { log.Error(). @@ -572,25 +579,25 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( Msg("could not find or create namespace") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not find or create namespace")) - if err != nil { + _, werr := writer.Write([]byte("could not find or create namespace")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, err } - return namespace, true + return namespace, nil } func (h *Headscale) registerMachineForOIDCCallback( writer http.ResponseWriter, namespace *Namespace, machineKey *key.MachinePublic, -) bool { +) error { machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) if _, err := h.RegisterMachineFromAuthCallback( @@ -604,24 +611,24 @@ func (h *Headscale) registerMachineForOIDCCallback( Msg("could not register machine") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not register machine")) - if err != nil { + _, werr := writer.Write([]byte("could not register machine")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return false + return err } - return true + return nil } func renderOIDCCallbackTemplate( writer http.ResponseWriter, claims *IDTokenClaims, -) (*bytes.Buffer, bool) { +) (*bytes.Buffer, error) { var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ User: claims.Email, @@ -635,16 +642,16 @@ func renderOIDCCallbackTemplate( writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render OIDC callback template")) - if err != nil { + _, werr := writer.Write([]byte("Could not render OIDC callback template")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false + return nil, err } - return &content, true + return &content, nil }