1use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21 DatabaseError,
22 filter::{Filter, StatementExt},
23 iden::Users,
24 pagination::QueryBuilderExt,
25 tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40 email::PgUserEmailRepository, password::PgUserPasswordRepository,
41 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43 terms::PgUserTermsRepository,
44};
45
46pub struct PgUserRepository<'c> {
48 conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
58mod priv_ {
59 #![allow(missing_docs)]
62
63 use chrono::{DateTime, Utc};
64 use mas_storage::pagination::Node;
65 use sea_query::enum_def;
66 use ulid::Ulid;
67 use uuid::Uuid;
68
69 #[derive(Debug, Clone, sqlx::FromRow)]
70 #[enum_def]
71 pub(super) struct UserLookup {
72 pub(super) user_id: Uuid,
73 pub(super) username: String,
74 pub(super) created_at: DateTime<Utc>,
75 pub(super) locked_at: Option<DateTime<Utc>>,
76 pub(super) deactivated_at: Option<DateTime<Utc>>,
77 pub(super) can_request_admin: bool,
78 pub(super) is_guest: bool,
79 }
80
81 impl Node<Ulid> for UserLookup {
82 fn cursor(&self) -> Ulid {
83 self.user_id.into()
84 }
85 }
86}
87
88use priv_::{UserLookup, UserLookupIden};
89
90impl From<UserLookup> for User {
91 fn from(value: UserLookup) -> Self {
92 let id = value.user_id.into();
93 Self {
94 id,
95 username: value.username,
96 sub: id.to_string(),
97 created_at: value.created_at,
98 locked_at: value.locked_at,
99 deactivated_at: value.deactivated_at,
100 can_request_admin: value.can_request_admin,
101 is_guest: value.is_guest,
102 }
103 }
104}
105
106impl Filter for UserFilter<'_> {
107 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
108 sea_query::Condition::all()
109 .add_option(self.state().map(|state| {
110 match state {
111 mas_storage::user::UserState::Deactivated => {
112 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
113 }
114 mas_storage::user::UserState::Locked => {
115 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
116 }
117 mas_storage::user::UserState::Active => {
118 Expr::col((Users::Table, Users::LockedAt))
119 .is_null()
120 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
121 }
122 }
123 }))
124 .add_option(self.can_request_admin().map(|can_request_admin| {
125 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
126 }))
127 .add_option(
128 self.is_guest()
129 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
130 )
131 .add_option(self.search().map(|search| {
132 Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
133 }))
134 }
135}
136
137#[async_trait]
138impl UserRepository for PgUserRepository<'_> {
139 type Error = DatabaseError;
140
141 #[tracing::instrument(
142 name = "db.user.lookup",
143 skip_all,
144 fields(
145 db.query.text,
146 user.id = %id,
147 ),
148 err,
149 )]
150 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
151 let res = sqlx::query_as!(
152 UserLookup,
153 r#"
154 SELECT user_id
155 , username
156 , created_at
157 , locked_at
158 , deactivated_at
159 , can_request_admin
160 , is_guest
161 FROM users
162 WHERE user_id = $1
163 "#,
164 Uuid::from(id),
165 )
166 .traced()
167 .fetch_optional(&mut *self.conn)
168 .await?;
169
170 let Some(res) = res else { return Ok(None) };
171
172 Ok(Some(res.into()))
173 }
174
175 #[tracing::instrument(
176 name = "db.user.find_by_username",
177 skip_all,
178 fields(
179 db.query.text,
180 user.username = username,
181 ),
182 err,
183 )]
184 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
185 let res = sqlx::query_as!(
189 UserLookup,
190 r#"
191 SELECT user_id
192 , username
193 , created_at
194 , locked_at
195 , deactivated_at
196 , can_request_admin
197 , is_guest
198 FROM users
199 WHERE LOWER(username) = LOWER($1)
200 "#,
201 username,
202 )
203 .traced()
204 .fetch_all(&mut *self.conn)
205 .await?;
206
207 match &res[..] {
208 [user] => Ok(Some(user.clone().into())),
210 [] => Ok(None),
212 list => {
213 if let Some(user) = list.iter().find(|user| user.username == username) {
216 Ok(Some(user.clone().into()))
217 } else {
218 Ok(None)
220 }
221 }
222 }
223 }
224
225 #[tracing::instrument(
226 name = "db.user.add",
227 skip_all,
228 fields(
229 db.query.text,
230 user.username = username,
231 user.id,
232 ),
233 err,
234 )]
235 async fn add(
236 &mut self,
237 rng: &mut (dyn RngCore + Send),
238 clock: &dyn Clock,
239 username: String,
240 ) -> Result<User, Self::Error> {
241 let created_at = clock.now();
242 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
243 tracing::Span::current().record("user.id", tracing::field::display(id));
244
245 let res = sqlx::query!(
246 r#"
247 INSERT INTO users (user_id, username, created_at)
248 VALUES ($1, $2, $3)
249 ON CONFLICT (username) DO NOTHING
250 "#,
251 Uuid::from(id),
252 username,
253 created_at,
254 )
255 .traced()
256 .execute(&mut *self.conn)
257 .await?;
258
259 DatabaseError::ensure_affected_rows(&res, 1)?;
262
263 Ok(User {
264 id,
265 username,
266 sub: id.to_string(),
267 created_at,
268 locked_at: None,
269 deactivated_at: None,
270 can_request_admin: false,
271 is_guest: false,
272 })
273 }
274
275 #[tracing::instrument(
276 name = "db.user.exists",
277 skip_all,
278 fields(
279 db.query.text,
280 user.username = username,
281 ),
282 err,
283 )]
284 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
285 let exists = sqlx::query_scalar!(
286 r#"
287 SELECT EXISTS(
288 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
289 ) AS "exists!"
290 "#,
291 username
292 )
293 .traced()
294 .fetch_one(&mut *self.conn)
295 .await?;
296
297 Ok(exists)
298 }
299
300 #[tracing::instrument(
301 name = "db.user.lock",
302 skip_all,
303 fields(
304 db.query.text,
305 %user.id,
306 ),
307 err,
308 )]
309 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
310 if user.locked_at.is_some() {
311 return Ok(user);
312 }
313
314 let locked_at = clock.now();
315 let res = sqlx::query!(
316 r#"
317 UPDATE users
318 SET locked_at = $1
319 WHERE user_id = $2
320 "#,
321 locked_at,
322 Uuid::from(user.id),
323 )
324 .traced()
325 .execute(&mut *self.conn)
326 .await?;
327
328 DatabaseError::ensure_affected_rows(&res, 1)?;
329
330 user.locked_at = Some(locked_at);
331
332 Ok(user)
333 }
334
335 #[tracing::instrument(
336 name = "db.user.unlock",
337 skip_all,
338 fields(
339 db.query.text,
340 %user.id,
341 ),
342 err,
343 )]
344 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
345 if user.locked_at.is_none() {
346 return Ok(user);
347 }
348
349 let res = sqlx::query!(
350 r#"
351 UPDATE users
352 SET locked_at = NULL
353 WHERE user_id = $1
354 "#,
355 Uuid::from(user.id),
356 )
357 .traced()
358 .execute(&mut *self.conn)
359 .await?;
360
361 DatabaseError::ensure_affected_rows(&res, 1)?;
362
363 user.locked_at = None;
364
365 Ok(user)
366 }
367
368 #[tracing::instrument(
369 name = "db.user.deactivate",
370 skip_all,
371 fields(
372 db.query.text,
373 %user.id,
374 ),
375 err,
376 )]
377 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
378 if user.deactivated_at.is_some() {
379 return Ok(user);
380 }
381
382 let deactivated_at = clock.now();
383 let res = sqlx::query!(
384 r#"
385 UPDATE users
386 SET deactivated_at = $2
387 WHERE user_id = $1
388 AND deactivated_at IS NULL
389 "#,
390 Uuid::from(user.id),
391 deactivated_at,
392 )
393 .traced()
394 .execute(&mut *self.conn)
395 .await?;
396
397 DatabaseError::ensure_affected_rows(&res, 1)?;
398
399 user.deactivated_at = Some(deactivated_at);
400
401 Ok(user)
402 }
403
404 #[tracing::instrument(
405 name = "db.user.reactivate",
406 skip_all,
407 fields(
408 db.query.text,
409 %user.id,
410 ),
411 err,
412 )]
413 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
414 if user.deactivated_at.is_none() {
415 return Ok(user);
416 }
417
418 let res = sqlx::query!(
419 r#"
420 UPDATE users
421 SET deactivated_at = NULL
422 WHERE user_id = $1
423 "#,
424 Uuid::from(user.id),
425 )
426 .traced()
427 .execute(&mut *self.conn)
428 .await?;
429
430 DatabaseError::ensure_affected_rows(&res, 1)?;
431
432 user.deactivated_at = None;
433
434 Ok(user)
435 }
436
437 #[tracing::instrument(
438 name = "db.user.set_can_request_admin",
439 skip_all,
440 fields(
441 db.query.text,
442 %user.id,
443 user.can_request_admin = can_request_admin,
444 ),
445 err,
446 )]
447 async fn set_can_request_admin(
448 &mut self,
449 mut user: User,
450 can_request_admin: bool,
451 ) -> Result<User, Self::Error> {
452 let res = sqlx::query!(
453 r#"
454 UPDATE users
455 SET can_request_admin = $2
456 WHERE user_id = $1
457 "#,
458 Uuid::from(user.id),
459 can_request_admin,
460 )
461 .traced()
462 .execute(&mut *self.conn)
463 .await?;
464
465 DatabaseError::ensure_affected_rows(&res, 1)?;
466
467 user.can_request_admin = can_request_admin;
468
469 Ok(user)
470 }
471
472 #[tracing::instrument(
473 name = "db.user.list",
474 skip_all,
475 fields(
476 db.query.text,
477 ),
478 err,
479 )]
480 async fn list(
481 &mut self,
482 filter: UserFilter<'_>,
483 pagination: mas_storage::Pagination,
484 ) -> Result<mas_storage::Page<User>, Self::Error> {
485 let (sql, arguments) = Query::select()
486 .expr_as(
487 Expr::col((Users::Table, Users::UserId)),
488 UserLookupIden::UserId,
489 )
490 .expr_as(
491 Expr::col((Users::Table, Users::Username)),
492 UserLookupIden::Username,
493 )
494 .expr_as(
495 Expr::col((Users::Table, Users::CreatedAt)),
496 UserLookupIden::CreatedAt,
497 )
498 .expr_as(
499 Expr::col((Users::Table, Users::LockedAt)),
500 UserLookupIden::LockedAt,
501 )
502 .expr_as(
503 Expr::col((Users::Table, Users::DeactivatedAt)),
504 UserLookupIden::DeactivatedAt,
505 )
506 .expr_as(
507 Expr::col((Users::Table, Users::CanRequestAdmin)),
508 UserLookupIden::CanRequestAdmin,
509 )
510 .expr_as(
511 Expr::col((Users::Table, Users::IsGuest)),
512 UserLookupIden::IsGuest,
513 )
514 .from(Users::Table)
515 .apply_filter(filter)
516 .generate_pagination((Users::Table, Users::UserId), pagination)
517 .build_sqlx(PostgresQueryBuilder);
518
519 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
520 .traced()
521 .fetch_all(&mut *self.conn)
522 .await?;
523
524 let page = pagination.process(edges).map(User::from);
525
526 Ok(page)
527 }
528
529 #[tracing::instrument(
530 name = "db.user.count",
531 skip_all,
532 fields(
533 db.query.text,
534 ),
535 err,
536 )]
537 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
538 let (sql, arguments) = Query::select()
539 .expr(Expr::col((Users::Table, Users::UserId)).count())
540 .from(Users::Table)
541 .apply_filter(filter)
542 .build_sqlx(PostgresQueryBuilder);
543
544 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
545 .traced()
546 .fetch_one(&mut *self.conn)
547 .await?;
548
549 count
550 .try_into()
551 .map_err(DatabaseError::to_invalid_operation)
552 }
553
554 #[tracing::instrument(
555 name = "db.user.acquire_lock_for_sync",
556 skip_all,
557 fields(
558 db.query.text,
559 user.id = %user.id,
560 ),
561 err,
562 )]
563 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
564 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
572
573 sqlx::query!(
576 r#"
577 SELECT pg_advisory_xact_lock($1)
578 "#,
579 lock_id,
580 )
581 .traced()
582 .execute(&mut *self.conn)
583 .await?;
584
585 Ok(())
586 }
587}