mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11    UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12    UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
13    UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26    /// The OIDC issuer of the provider
27    pub issuer: Option<String>,
28
29    /// A human-readable name for the provider
30    pub human_name: Option<String>,
31
32    /// A brand identifier, e.g. "apple" or "google"
33    pub brand_name: Option<String>,
34
35    /// The scope to request during the authorization flow
36    pub scope: Scope,
37
38    /// The token endpoint authentication method
39    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41    /// The JWT signing algorithm to use when then `client_secret_jwt` or
42    /// `private_key_jwt` authentication methods are used
43    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45    /// Expected signature for the JWT payload returned by the token
46    /// authentication endpoint.
47    ///
48    /// Defaults to `RS256`.
49    pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51    /// Whether to fetch the user profile from the userinfo endpoint,
52    /// or to rely on the data returned in the `id_token` from the
53    /// `token_endpoint`.
54    pub fetch_userinfo: bool,
55
56    /// Expected signature for the JWT payload returned by the userinfo
57    /// endpoint.
58    ///
59    /// If not specified, the response is expected to be an unsigned JSON
60    /// payload. Defaults to `None`.
61    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63    /// The client ID to use when authenticating to the upstream
64    pub client_id: String,
65
66    /// The encrypted client secret to use when authenticating to the upstream
67    pub encrypted_client_secret: Option<String>,
68
69    /// How claims should be imported from the upstream provider
70    pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72    /// The URL to use as the authorization endpoint. If `None`, the URL will be
73    /// discovered
74    pub authorization_endpoint_override: Option<Url>,
75
76    /// The URL to use as the token endpoint. If `None`, the URL will be
77    /// discovered
78    pub token_endpoint_override: Option<Url>,
79
80    /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81    /// discovered
82    pub userinfo_endpoint_override: Option<Url>,
83
84    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85    pub jwks_uri_override: Option<Url>,
86
87    /// How the provider metadata should be discovered
88    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90    /// How should PKCE be used
91    pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93    /// What response mode it should ask
94    pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96    /// Additional parameters to include in the authorization request
97    pub additional_authorization_parameters: Vec<(String, String)>,
98
99    /// Whether to forward the login hint to the upstream provider.
100    pub forward_login_hint: bool,
101
102    /// The position of the provider in the UI
103    pub ui_order: i32,
104}
105
/// Criteria used to narrow down a listing of upstream OAuth 2.0 providers
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restricts results by the provider's enabled state
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` applies no restriction at all
    enabled: Option<bool>,

    /// Zero-sized marker that carries the `'a` lifetime parameter
    _lifetime: PhantomData<&'a ()>,
}
116
117impl UpstreamOAuthProviderFilter<'_> {
118    /// Create a new [`UpstreamOAuthProviderFilter`] with default values
119    #[must_use]
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    /// Return only enabled providers
125    #[must_use]
126    pub const fn enabled_only(mut self) -> Self {
127        self.enabled = Some(true);
128        self
129    }
130
131    /// Return only disabled providers
132    #[must_use]
133    pub const fn disabled_only(mut self) -> Self {
134        self.enabled = Some(false);
135        self
136    }
137
138    /// Get the enabled filter
139    ///
140    /// Returns `None` if the filter is not set
141    #[must_use]
142    pub const fn enabled(&self) -> Option<bool> {
143        self.enabled
144    }
145}
146
147/// An [`UpstreamOAuthProviderRepository`] helps interacting with
148/// [`UpstreamOAuthProvider`] saved in the storage backend
149#[async_trait]
150pub trait UpstreamOAuthProviderRepository: Send + Sync {
151    /// The error type returned by the repository
152    type Error;
153
154    /// Lookup an upstream OAuth provider by its ID
155    ///
156    /// Returns `None` if the provider was not found
157    ///
158    /// # Parameters
159    ///
160    /// * `id`: The ID of the provider to lookup
161    ///
162    /// # Errors
163    ///
164    /// Returns [`Self::Error`] if the underlying repository fails
165    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
166
167    /// Add a new upstream OAuth provider
168    ///
169    /// Returns the newly created provider
170    ///
171    /// # Parameters
172    ///
173    /// * `rng`: A random number generator
174    /// * `clock`: The clock used to generate timestamps
175    /// * `params`: The parameters of the provider to add
176    ///
177    /// # Errors
178    ///
179    /// Returns [`Self::Error`] if the underlying repository fails
180    async fn add(
181        &mut self,
182        rng: &mut (dyn RngCore + Send),
183        clock: &dyn Clock,
184        params: UpstreamOAuthProviderParams,
185    ) -> Result<UpstreamOAuthProvider, Self::Error>;
186
187    /// Delete an upstream OAuth provider
188    ///
189    /// # Parameters
190    ///
191    /// * `provider`: The provider to delete
192    ///
193    /// # Errors
194    ///
195    /// Returns [`Self::Error`] if the underlying repository fails
196    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
197        self.delete_by_id(provider.id).await
198    }
199
200    /// Delete an upstream OAuth provider by its ID
201    ///
202    /// # Parameters
203    ///
204    /// * `id`: The ID of the provider to delete
205    ///
206    /// # Errors
207    ///
208    /// Returns [`Self::Error`] if the underlying repository fails
209    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
210
211    /// Insert or update an upstream OAuth provider
212    ///
213    /// # Parameters
214    ///
215    /// * `clock`: The clock used to generate timestamps
216    /// * `id`: The ID of the provider to update
217    /// * `params`: The parameters of the provider to update
218    ///
219    /// # Errors
220    ///
221    /// Returns [`Self::Error`] if the underlying repository fails
222    async fn upsert(
223        &mut self,
224        clock: &dyn Clock,
225        id: Ulid,
226        params: UpstreamOAuthProviderParams,
227    ) -> Result<UpstreamOAuthProvider, Self::Error>;
228
229    /// Disable an upstream OAuth provider
230    ///
231    /// Returns the disabled provider
232    ///
233    /// # Parameters
234    ///
235    /// * `clock`: The clock used to generate timestamps
236    /// * `provider`: The provider to disable
237    ///
238    /// # Errors
239    ///
240    /// Returns [`Self::Error`] if the underlying repository fails
241    async fn disable(
242        &mut self,
243        clock: &dyn Clock,
244        provider: UpstreamOAuthProvider,
245    ) -> Result<UpstreamOAuthProvider, Self::Error>;
246
247    /// List [`UpstreamOAuthProvider`] with the given filter and pagination
248    ///
249    /// # Parameters
250    ///
251    /// * `filter`: The filter to apply
252    /// * `pagination`: The pagination parameters
253    ///
254    /// # Errors
255    ///
256    /// Returns [`Self::Error`] if the underlying repository fails
257    async fn list(
258        &mut self,
259        filter: UpstreamOAuthProviderFilter<'_>,
260        pagination: Pagination,
261    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
262
263    /// Count the number of [`UpstreamOAuthProvider`] with the given filter
264    ///
265    /// # Parameters
266    ///
267    /// * `filter`: The filter to apply
268    ///
269    /// # Errors
270    ///
271    /// Returns [`Self::Error`] if the underlying repository fails
272    async fn count(
273        &mut self,
274        filter: UpstreamOAuthProviderFilter<'_>,
275    ) -> Result<usize, Self::Error>;
276
277    /// Get all enabled upstream OAuth providers
278    ///
279    /// # Errors
280    ///
281    /// Returns [`Self::Error`] if the underlying repository fails
282    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
283}
284
// Feed every method signature of `UpstreamOAuthProviderRepository` to the
// crate's `repository_impl!` macro (imported from the crate root). The list
// below has to stay in sync with the trait definition: every trait method
// appears here with the same signature.
// NOTE(review): presumably the macro generates delegating impls of the trait
// for wrapper types — confirm against the macro definition.
// NOTE(review): the last parameter of each signature has no trailing comma;
// the macro's matcher may not accept one, so check before reformatting.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    // `delete` has a default implementation in the trait, yet it is listed
    // here too; presumably so wrappers forward to the inner repository's own
    // `delete` instead of the default — confirm against the macro.
    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);