From 776843be76e58c2045e2f80e75a72bfd9cac9d83 Mon Sep 17 00:00:00 2001 From: Cyril Plisko Date: Tue, 4 Mar 2025 20:01:35 +0200 Subject: Add axum::OptionalQsQuery (#102) Addresses #101 --- src/axum.rs | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_axum.rs | 42 ++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/src/axum.rs b/src/axum.rs index 0493b72..e9a3c17 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -96,6 +96,90 @@ where } } +/// Extractor that differentiates between the absence and presence of the query string +/// using `Option`. Absence of query string encoded as `None`. Otherwise, it behaves +/// identical to the `QsQuery`. +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::OptionalQsQuery; +/// use serde_qs::Config; +/// use axum::{response::IntoResponse, routing::get, Router, body::Body}; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +/// id: Vec, +/// } +/// +/// async fn filter_users( +/// OptionalQsQuery(info): OptionalQsQuery +/// ) -> impl IntoResponse { +/// match info { +/// Some(info) => todo!("Select users based on query string"), +/// None => { todo!("No query string provided")} +/// } +/// } +/// +/// fn main() { +/// let app = Router::<()>::new() +/// .route("/users", get(filter_users)); +/// } +#[derive(Clone, Copy, Default)] +pub struct OptionalQsQuery(pub Option); + +impl std::ops::Deref for OptionalQsQuery { + type Target = Option; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for OptionalQsQuery { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl std::fmt::Debug for OptionalQsQuery { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[axum::async_trait] +impl FromRequestParts for OptionalQsQuery +where + T: serde::de::DeserializeOwned, + S: Send + Sync, +{ + type Rejection = QsQueryRejection; + + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> Result { + let Extension(qs_config) = Extension::::from_request_parts(parts, state) + .await + .unwrap_or_else(|_| Extension(QsQueryConfig::default())); + if let Some(query) = parts.uri.query() { + let error_handler = qs_config.error_handler.clone(); + let config: QsConfig = qs_config.into(); + config + .deserialize_str::(query) + .map(|query| OptionalQsQuery(Some(query))) + .map_err(|err| match error_handler { + Some(handler) => handler(err), + None => QsQueryRejection::new(err, StatusCode::BAD_REQUEST), + }) + } else { + Ok(OptionalQsQuery(None)) + } + } +} + #[derive(Debug)] /// Rejection type for extractors that deserialize query strings pub struct QsQueryRejection { diff --git a/tests/test_axum.rs b/tests/test_axum.rs index 16449b9..dd86a9e 100644 --- a/tests/test_axum.rs +++ b/tests/test_axum.rs @@ -8,7 +8,7 @@ extern crate axum_framework as axum; extern crate serde_qs as qs; use axum::{extract::FromRequestParts, http::StatusCode, response::IntoResponse}; -use qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; +use qs::axum::{OptionalQsQuery, QsQuery, QsQueryConfig, QsQueryRejection}; use serde::de::Error; fn from_str<'de, D, S>(deserializer: D) -> Result @@ -132,3 +132,43 @@ fn test_custom_qs_config() { assert!(s.common.remaining); }) } + +#[test] +fn test_optional_query_none() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test") + .body(()) + .unwrap(); + let (mut req_parts, _) = req.into_parts(); + + let OptionalQsQuery(s) = OptionalQsQuery::::from_request_parts(&mut req_parts, &()) + .await + .unwrap(); + + assert!(s.is_none()); + }) +} + +#[test] +fn test_optional_query_some() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true") + .extension(QsQueryConfig::new(5, false)) + .body(()) + .unwrap(); + + let (mut req_parts, _) = req.into_parts(); + let OptionalQsQuery(s) = OptionalQsQuery::::from_request_parts(&mut req_parts, &()) + .await + .unwrap(); + + let query = s.unwrap(); + assert_eq!(query.foo, 1); + assert_eq!(query.bars, vec![3]); + assert_eq!(query.common.limit, 100); + assert_eq!(query.common.offset, 50); + assert!(query.common.remaining); + }) +} -- cgit v1.2.3