1use std::num::NonZeroU32;
8
9use chrono::{DateTime, Duration, Utc};
10use mas_iana::oauth::PkceCodeChallengeMethod;
11use oauth2_types::{
12    pkce::{CodeChallengeError, CodeChallengeMethodExt},
13    requests::ResponseMode,
14    scope::{OPENID, PROFILE, Scope},
15};
16use rand::{
17    RngCore,
18    distributions::{Alphanumeric, DistString},
19};
20use ruma_common::UserId;
21use serde::Serialize;
22use ulid::Ulid;
23use url::Url;
24
25use super::session::Session;
26use crate::InvalidTransitionError;
27
28#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
29pub struct Pkce {
30    pub challenge_method: PkceCodeChallengeMethod,
31    pub challenge: String,
32}
33
34impl Pkce {
35    #[must_use]
37    pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
38        Pkce {
39            challenge_method,
40            challenge,
41        }
42    }
43
44    pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
50        self.challenge_method.verify(&self.challenge, verifier)
51    }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
55pub struct AuthorizationCode {
56    pub code: String,
57    pub pkce: Option<Pkce>,
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
61#[serde(tag = "stage", rename_all = "lowercase")]
62pub enum AuthorizationGrantStage {
63    #[default]
64    Pending,
65    Fulfilled {
66        session_id: Ulid,
67        fulfilled_at: DateTime<Utc>,
68    },
69    Exchanged {
70        session_id: Ulid,
71        fulfilled_at: DateTime<Utc>,
72        exchanged_at: DateTime<Utc>,
73    },
74    Cancelled {
75        cancelled_at: DateTime<Utc>,
76    },
77}
78
79impl AuthorizationGrantStage {
80    #[must_use]
81    pub fn new() -> Self {
82        Self::Pending
83    }
84
85    fn fulfill(
86        self,
87        fulfilled_at: DateTime<Utc>,
88        session: &Session,
89    ) -> Result<Self, InvalidTransitionError> {
90        match self {
91            Self::Pending => Ok(Self::Fulfilled {
92                fulfilled_at,
93                session_id: session.id,
94            }),
95            _ => Err(InvalidTransitionError),
96        }
97    }
98
99    fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
100        match self {
101            Self::Fulfilled {
102                fulfilled_at,
103                session_id,
104            } => Ok(Self::Exchanged {
105                fulfilled_at,
106                exchanged_at,
107                session_id,
108            }),
109            _ => Err(InvalidTransitionError),
110        }
111    }
112
113    fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
114        match self {
115            Self::Pending => Ok(Self::Cancelled { cancelled_at }),
116            _ => Err(InvalidTransitionError),
117        }
118    }
119
120    #[must_use]
124    pub fn is_pending(&self) -> bool {
125        matches!(self, Self::Pending)
126    }
127
128    #[must_use]
132    pub fn is_fulfilled(&self) -> bool {
133        matches!(self, Self::Fulfilled { .. })
134    }
135
136    #[must_use]
140    pub fn is_exchanged(&self) -> bool {
141        matches!(self, Self::Exchanged { .. })
142    }
143}
144
145pub enum LoginHint<'a> {
146    MXID(&'a UserId),
147    None,
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
151pub struct AuthorizationGrant {
152    pub id: Ulid,
153    #[serde(flatten)]
154    pub stage: AuthorizationGrantStage,
155    pub code: Option<AuthorizationCode>,
156    pub client_id: Ulid,
157    pub redirect_uri: Url,
158    pub scope: Scope,
159    pub state: Option<String>,
160    pub nonce: Option<String>,
161    pub max_age: Option<NonZeroU32>,
162    pub response_mode: ResponseMode,
163    pub response_type_id_token: bool,
164    pub created_at: DateTime<Utc>,
165    pub requires_consent: bool,
166    pub login_hint: Option<String>,
167}
168
169impl std::ops::Deref for AuthorizationGrant {
170    type Target = AuthorizationGrantStage;
171
172    fn deref(&self) -> &Self::Target {
173        &self.stage
174    }
175}
176
177const DEFAULT_MAX_AGE: Duration = Duration::microseconds(3600 * 24 * 365 * 1000 * 1000);
178
179impl AuthorizationGrant {
180    #[must_use]
181    pub fn max_auth_time(&self) -> DateTime<Utc> {
182        let max_age = self
183            .max_age
184            .and_then(|x| Duration::try_seconds(x.get().into()))
185            .unwrap_or(DEFAULT_MAX_AGE);
186        self.created_at - max_age
187    }
188
189    #[must_use]
190    pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
191        let Some(login_hint) = &self.login_hint else {
192            return LoginHint::None;
193        };
194
195        let Some((prefix, value)) = login_hint.split_once(':') else {
197            return LoginHint::None;
198        };
199
200        match prefix {
201            "mxid" => {
202                let Ok(mxid) = <&UserId>::try_from(value) else {
204                    return LoginHint::None;
205                };
206
207                if mxid.server_name() != homeserver {
209                    return LoginHint::None;
210                }
211
212                LoginHint::MXID(mxid)
213            }
214            _ => LoginHint::None,
216        }
217    }
218
219    pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
227        self.stage = self.stage.exchange(exchanged_at)?;
228        Ok(self)
229    }
230
231    pub fn fulfill(
239        mut self,
240        fulfilled_at: DateTime<Utc>,
241        session: &Session,
242    ) -> Result<Self, InvalidTransitionError> {
243        self.stage = self.stage.fulfill(fulfilled_at, session)?;
244        Ok(self)
245    }
246
247    pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
259        self.stage = self.stage.cancel(canceld_at)?;
260        Ok(self)
261    }
262
263    #[doc(hidden)]
264    pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
265        Self {
266            id: Ulid::from_datetime_with_source(now.into(), rng),
267            stage: AuthorizationGrantStage::Pending,
268            code: Some(AuthorizationCode {
269                code: Alphanumeric.sample_string(rng, 10),
270                pkce: None,
271            }),
272            client_id: Ulid::from_datetime_with_source(now.into(), rng),
273            redirect_uri: Url::parse("http://localhost:8080").unwrap(),
274            scope: Scope::from_iter([OPENID, PROFILE]),
275            state: Some(Alphanumeric.sample_string(rng, 10)),
276            nonce: Some(Alphanumeric.sample_string(rng, 10)),
277            max_age: None,
278            response_mode: ResponseMode::Query,
279            response_type_id_token: false,
280            created_at: now,
281            requires_consent: false,
282            login_hint: Some(String::from("mxid:@example-user:example.com")),
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Builds a sample grant whose `login_hint` is replaced by `login_hint`.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        // Tests do not need a reproducible RNG or clock.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn invalid_login_hint() {
        // No `<prefix>:` part at all.
        let grant = grant_with_login_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}