1use std::collections::HashMap;
8
9use axum::{
10    BoxError, Json,
11    extract::{
12        Form, FromRequest, FromRequestParts,
13        rejection::{FailedToDeserializeForm, FormRejection},
14    },
15    response::IntoResponse,
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, authorization::Basic};
19use http::{Request, StatusCode};
20use mas_data_model::{Client, JwksOrJwksUri};
21use mas_http::RequestBuilderExt;
22use mas_iana::oauth::OAuthClientAuthenticationMethod;
23use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
24use mas_keystore::Encrypter;
25use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
26use oauth2_types::errors::{ClientError, ClientErrorCode};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::Value;
29use thiserror::Error;
30
/// The only `client_assertion_type` value this extractor understands: a JWT
/// used as a client authentication assertion.
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
32
/// URL-encoded form body for endpoints that take client credentials alongside
/// their own parameters.
///
/// The four credential-related fields are pulled out of the body by name; every
/// other field is deserialized into `inner` via `#[serde(flatten)]`.
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    /// `client_id` form parameter. Required for `client_secret_post` and `none`;
    /// optionally sent alongside a Basic header or a client assertion, in which
    /// case it must match the other credential.
    client_id: Option<String>,
    /// `client_secret` form parameter (`client_secret_post`).
    client_secret: Option<String>,
    /// `client_assertion_type` form parameter; only the JWT bearer type is
    /// accepted by the extractor.
    client_assertion_type: Option<String>,
    /// `client_assertion` form parameter carrying the assertion itself.
    client_assertion: Option<String>,

    /// The remaining, endpoint-specific form fields.
    #[serde(flatten)]
    inner: F,
}
43
/// Credentials a client presented to authenticate itself.
///
/// Building this value only parses what the request carried; nothing is
/// checked against the registered client until [`Credentials::verify`] runs.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` was sent (the `none` authentication method).
    None {
        client_id: String,
    },
    /// `client_id` and `client_secret` sent in an `Authorization: Basic` header.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// `client_id` and `client_secret` sent as form parameters.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A JWT sent as `client_assertion`. The `client_id` is taken from the
    /// JWT's `sub` claim, and the signature is checked later by
    /// [`Credentials::verify`].
    ClientAssertionJwtBearer {
        client_id: String,
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
}
62
63impl Credentials {
64    #[must_use]
66    pub fn client_id(&self) -> &str {
67        match self {
68            Credentials::None { client_id }
69            | Credentials::ClientSecretBasic { client_id, .. }
70            | Credentials::ClientSecretPost { client_id, .. }
71            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
72        }
73    }
74
75    pub async fn fetch<E>(
82        &self,
83        repo: &mut impl RepositoryAccess<Error = E>,
84    ) -> Result<Option<Client>, E> {
85        let client_id = match self {
86            Credentials::None { client_id }
87            | Credentials::ClientSecretBasic { client_id, .. }
88            | Credentials::ClientSecretPost { client_id, .. }
89            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
90        };
91
92        repo.oauth2_client().find_by_client_id(client_id).await
93    }
94
95    #[tracing::instrument(skip_all, err)]
101    pub async fn verify(
102        &self,
103        http_client: &reqwest::Client,
104        encrypter: &Encrypter,
105        method: &OAuthClientAuthenticationMethod,
106        client: &Client,
107    ) -> Result<(), CredentialsVerificationError> {
108        match (self, method) {
109            (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
110
111            (
112                Credentials::ClientSecretPost { client_secret, .. },
113                OAuthClientAuthenticationMethod::ClientSecretPost,
114            )
115            | (
116                Credentials::ClientSecretBasic { client_secret, .. },
117                OAuthClientAuthenticationMethod::ClientSecretBasic,
118            ) => {
119                let encrypted_client_secret = client
121                    .encrypted_client_secret
122                    .as_ref()
123                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
124
125                let decrypted_client_secret = encrypter
126                    .decrypt_string(encrypted_client_secret)
127                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
128
129                if client_secret.as_bytes() != decrypted_client_secret {
131                    return Err(CredentialsVerificationError::ClientSecretMismatch);
132                }
133            }
134
135            (
136                Credentials::ClientAssertionJwtBearer { jwt, .. },
137                OAuthClientAuthenticationMethod::PrivateKeyJwt,
138            ) => {
139                let jwks = client
141                    .jwks
142                    .as_ref()
143                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
144
145                let jwks = fetch_jwks(http_client, jwks)
146                    .await
147                    .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
148
149                jwt.verify_with_jwks(&jwks)
150                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
151            }
152
153            (
154                Credentials::ClientAssertionJwtBearer { jwt, .. },
155                OAuthClientAuthenticationMethod::ClientSecretJwt,
156            ) => {
157                let encrypted_client_secret = client
159                    .encrypted_client_secret
160                    .as_ref()
161                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
162
163                let decrypted_client_secret = encrypter
164                    .decrypt_string(encrypted_client_secret)
165                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
166
167                jwt.verify_with_shared_secret(decrypted_client_secret)
168                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
169            }
170
171            (_, _) => {
172                return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
173            }
174        };
175        Ok(())
176    }
177}
178
179async fn fetch_jwks(
180    http_client: &reqwest::Client,
181    jwks: &JwksOrJwksUri,
182) -> Result<PublicJsonWebKeySet, BoxError> {
183    let uri = match jwks {
184        JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
185        JwksOrJwksUri::JwksUri(u) => u,
186    };
187
188    let response = http_client
189        .get(uri.as_str())
190        .send_traced()
191        .await?
192        .error_for_status()?
193        .json()
194        .await?;
195
196    Ok(response)
197}
198
/// Reasons [`Credentials::verify`] can reject a client's credentials.
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
    /// The client's stored secret could not be decrypted.
    #[error("failed to decrypt client credentials")]
    DecryptionError,

    /// The client is registered for a method, but lacks the stored secret or
    /// JWKS that method needs.
    #[error("invalid client configuration")]
    InvalidClientConfig,

    /// The presented client secret differs from the stored one.
    #[error("client secret did not match")]
    ClientSecretMismatch,

    /// The kind of credentials presented does not match the client's
    /// registered authentication method.
    #[error("authentication method mismatch")]
    AuthenticationMethodMismatch,

    /// The client assertion's signature did not verify.
    #[error("invalid assertion signature")]
    InvalidAssertionSignature,

    /// The client's key set could not be obtained (e.g. fetching `jwks_uri`
    /// failed).
    #[error("failed to fetch jwks")]
    JwksFetchFailed,
}
219
/// Axum extractor for client credentials plus the rest of a form body.
///
/// Extraction only parses the credentials. They still have to be checked
/// against the registered client with [`Credentials::verify`].
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    /// The credentials the client presented.
    pub credentials: Credentials,
    /// The remaining form fields, or `None` when the request did not have a
    /// URL-encoded form content type.
    pub form: Option<F>,
}
225
226impl<F> ClientAuthorization<F> {
227    #[must_use]
229    pub fn client_id(&self) -> &str {
230        self.credentials.client_id()
231    }
232}
233
/// Rejection produced when [`ClientAuthorization`] cannot be extracted.
#[derive(Debug)]
pub enum ClientAuthorizationError {
    /// An `Authorization` header was present but was not a valid Basic header.
    InvalidHeader,
    /// The form body could not be deserialized.
    BadForm(FailedToDeserializeForm),
    /// The form's `client_id` differs from the one in the credential.
    ClientIdMismatch { credential: String, form: String },
    /// A `client_assertion_type` other than the JWT bearer type was used.
    UnsupportedClientAssertion { client_assertion_type: String },
    /// Neither the header nor the form carried any credentials.
    MissingCredentials,
    /// The credential parameters were present in an unsupported combination.
    InvalidRequest,
    /// `client_assertion` is not a parseable JWT, or lacks a string `sub` claim.
    InvalidAssertion,
    /// Any other failure while extracting the form body.
    Internal(Box<dyn std::error::Error>),
}
245
246impl IntoResponse for ClientAuthorizationError {
247    fn into_response(self) -> axum::response::Response {
248        match self {
249            ClientAuthorizationError::InvalidHeader => (
250                StatusCode::BAD_REQUEST,
251                Json(ClientError::new(
252                    ClientErrorCode::InvalidRequest,
253                    "Invalid Authorization header",
254                )),
255            ),
256
257            ClientAuthorizationError::BadForm(err) => (
258                StatusCode::BAD_REQUEST,
259                Json(
260                    ClientError::from(ClientErrorCode::InvalidRequest)
261                        .with_description(format!("{err}")),
262                ),
263            ),
264
265            ClientAuthorizationError::ClientIdMismatch { form, credential } => {
266                let description = format!(
267                    "client_id in form ({form:?}) does not match credential ({credential:?})"
268                );
269
270                (
271                    StatusCode::BAD_REQUEST,
272                    Json(
273                        ClientError::from(ClientErrorCode::InvalidGrant)
274                            .with_description(description),
275                    ),
276                )
277            }
278
279            ClientAuthorizationError::UnsupportedClientAssertion {
280                client_assertion_type,
281            } => (
282                StatusCode::BAD_REQUEST,
283                Json(
284                    ClientError::from(ClientErrorCode::InvalidRequest).with_description(format!(
285                        "Unsupported client_assertion_type: {client_assertion_type}",
286                    )),
287                ),
288            ),
289
290            ClientAuthorizationError::MissingCredentials => (
291                StatusCode::BAD_REQUEST,
292                Json(ClientError::new(
293                    ClientErrorCode::InvalidRequest,
294                    "No credentials were presented",
295                )),
296            ),
297
298            ClientAuthorizationError::InvalidRequest => (
299                StatusCode::BAD_REQUEST,
300                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
301            ),
302
303            ClientAuthorizationError::InvalidAssertion => (
304                StatusCode::BAD_REQUEST,
305                Json(ClientError::new(
306                    ClientErrorCode::InvalidRequest,
307                    "Invalid client_assertion",
308                )),
309            ),
310
311            ClientAuthorizationError::Internal(e) => (
312                StatusCode::INTERNAL_SERVER_ERROR,
313                Json(
314                    ClientError::from(ClientErrorCode::ServerError)
315                        .with_description(format!("{e}")),
316                ),
317            ),
318        }
319        .into_response()
320    }
321}
322
323impl<S, F> FromRequest<S> for ClientAuthorization<F>
324where
325    F: DeserializeOwned,
326    S: Send + Sync,
327{
328    type Rejection = ClientAuthorizationError;
329
330    #[allow(clippy::too_many_lines)]
331    async fn from_request(
332        req: Request<axum::body::Body>,
333        state: &S,
334    ) -> Result<Self, Self::Rejection> {
335        let (mut parts, body) = req.into_parts();
337
338        let header =
339            TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
340
341        let credentials_from_header = match header {
343            Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
344            Err(err) => match err.reason() {
345                TypedHeaderRejectionReason::Missing => None,
347                _ => return Err(ClientAuthorizationError::InvalidHeader),
349            },
350        };
351
352        let req = Request::from_parts(parts, body);
354
355        let (
357            client_id_from_form,
358            client_secret_from_form,
359            client_assertion_type,
360            client_assertion,
361            form,
362        ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
363            Ok(Form(form)) => (
364                form.client_id,
365                form.client_secret,
366                form.client_assertion_type,
367                form.client_assertion,
368                Some(form.inner),
369            ),
370            Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
372            Err(FormRejection::FailedToDeserializeForm(err)) => {
374                return Err(ClientAuthorizationError::BadForm(err));
375            }
376            Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
378        };
379
380        let credentials = match (
382            credentials_from_header,
383            client_id_from_form,
384            client_secret_from_form,
385            client_assertion_type,
386            client_assertion,
387        ) {
388            (Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
389                if let Some(client_id_from_form) = client_id_from_form {
390                    if client_id != client_id_from_form {
392                        return Err(ClientAuthorizationError::ClientIdMismatch {
393                            credential: client_id,
394                            form: client_id_from_form,
395                        });
396                    }
397                }
398
399                Credentials::ClientSecretBasic {
400                    client_id,
401                    client_secret,
402                }
403            }
404
405            (None, Some(client_id), Some(client_secret), None, None) => {
406                Credentials::ClientSecretPost {
408                    client_id,
409                    client_secret,
410                }
411            }
412
413            (None, Some(client_id), None, None, None) => {
414                Credentials::None { client_id }
416            }
417
418            (
419                None,
420                client_id_from_form,
421                None,
422                Some(client_assertion_type),
423                Some(client_assertion),
424            ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
425                let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
427                    .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
428
429                let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
430                    client_id.clone()
431                } else {
432                    return Err(ClientAuthorizationError::InvalidAssertion);
433                };
434
435                if let Some(client_id_from_form) = client_id_from_form {
436                    if client_id != client_id_from_form {
438                        return Err(ClientAuthorizationError::ClientIdMismatch {
439                            credential: client_id,
440                            form: client_id_from_form,
441                        });
442                    }
443                }
444
445                Credentials::ClientAssertionJwtBearer {
446                    client_id,
447                    jwt: Box::new(jwt),
448                }
449            }
450
451            (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
452                return Err(ClientAuthorizationError::UnsupportedClientAssertion {
454                    client_assertion_type,
455                });
456            }
457
458            (None, None, None, None, None) => {
459                return Err(ClientAuthorizationError::MissingCredentials);
461            }
462
463            _ => {
464                return Err(ClientAuthorizationError::InvalidRequest);
466            }
467        };
468
469        Ok(ClientAuthorization { credentials, form })
470    }
471}
472
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization` header value encoding `client-id:client-secret` as Basic.
    const BASIC_CREDENTIALS: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a URL-encoded `POST` request with `body` and, optionally, a raw
    /// `Authorization` header value.
    fn form_request(body: impl Into<String>, authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        let body: String = body.into();
        builder.body(Body::new(body)).unwrap()
    }

    /// Run the extractor, capturing the extra form fields as JSON.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    #[tokio::test]
    async fn none_test() {
        let authz = extract(form_request("client_id=client-id&foo=bar", None))
            .await
            .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: Some(serde_json::json!({"foo": "bar"})),
        };

        // Credentials only in the header
        let authz = extract(form_request("foo=bar", Some(BASIC_CREDENTIALS)))
            .await
            .unwrap();
        assert_eq!(authz, expected);

        // A form `client_id` agreeing with the header is accepted
        let authz = extract(form_request(
            "client_id=client-id&foo=bar",
            Some(BASIC_CREDENTIALS),
        ))
        .await
        .unwrap();
        assert_eq!(authz, expected);

        // A form `client_id` disagreeing with the header is rejected
        let outcome = extract(form_request(
            "client_id=mismatch-id&foo=bar",
            Some(BASIC_CREDENTIALS),
        ))
        .await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // A malformed Basic header is rejected
        let outcome = extract(form_request("foo=bar", Some("Basic invalid"))).await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let authz = extract(form_request(
            "client_id=client-id&client_secret=client-secret&foo=bar",
            None,
        ))
        .await
        .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
        );

        let authz = extract(form_request(body, None)).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }
}
652}