diff options
| -rw-r--r-- | src/axum.rs | 84 | ||||
| -rw-r--r-- | tests/test_axum.rs | 42 | 
2 files changed, 125 insertions, 1 deletions
| diff --git a/src/axum.rs b/src/axum.rs index 0493b72..e9a3c17 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -96,6 +96,90 @@ where      }  } +/// Extractor that differentiates between the absence and presence of the query string +/// using `Option<T>`. Absence of query string encoded as `None`. Otherwise, it behaves +/// identical to the `QsQuery`. +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::OptionalQsQuery; +/// use serde_qs::Config; +/// use axum::{response::IntoResponse, routing::get, Router, body::Body}; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +///    id: Vec<u64>, +/// } +/// +/// async fn filter_users( +///     OptionalQsQuery(info): OptionalQsQuery<UsersFilter> +/// ) -> impl IntoResponse { +///     match info { +///         Some(info) => todo!("Select users based on query string"), +///         None => { todo!("No query string provided")} +///     } +/// } +/// +/// fn main() { +///     let app = Router::<()>::new() +///         .route("/users", get(filter_users)); +/// } +#[derive(Clone, Copy, Default)] +pub struct OptionalQsQuery<T>(pub Option<T>); + +impl<T> std::ops::Deref for OptionalQsQuery<T> { +    type Target = Option<T>; + +    fn deref(&self) -> &Self::Target { +        &self.0 +    } +} + +impl<T> std::ops::DerefMut for OptionalQsQuery<T> { +    fn deref_mut(&mut self) -> &mut Self::Target { +        &mut self.0 +    } +} + +impl<T: std::fmt::Debug> std::fmt::Debug for OptionalQsQuery<T> { +    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +        self.0.fmt(f) +    } +} + +#[axum::async_trait] +impl<T, S> FromRequestParts<S> for OptionalQsQuery<T> +where +    T: serde::de::DeserializeOwned, +    S: Send + Sync, +{ +    type Rejection = QsQueryRejection; + +    async fn from_request_parts( +        parts: &mut axum::http::request::Parts, +        state: &S, +    ) -> Result<Self, Self::Rejection> { +        let Extension(qs_config) = Extension::<QsQueryConfig>::from_request_parts(parts, state) +            .await +            .unwrap_or_else(|_| Extension(QsQueryConfig::default())); +        if let Some(query) = parts.uri.query() { +            let error_handler = qs_config.error_handler.clone(); +            let config: QsConfig = qs_config.into(); +            config +                .deserialize_str::<T>(query) +                .map(|query| OptionalQsQuery(Some(query))) +                .map_err(|err| match error_handler { +                    Some(handler) => handler(err), +                    None => QsQueryRejection::new(err, StatusCode::BAD_REQUEST), +                }) +        } else { +            Ok(OptionalQsQuery(None)) +        } +    } +} +  #[derive(Debug)]  /// Rejection type for extractors that deserialize query strings  pub struct QsQueryRejection { diff --git a/tests/test_axum.rs b/tests/test_axum.rs index 16449b9..dd86a9e 100644 --- a/tests/test_axum.rs +++ b/tests/test_axum.rs @@ -8,7 +8,7 @@ extern crate axum_framework as axum;  extern crate serde_qs as qs;  use axum::{extract::FromRequestParts, http::StatusCode, response::IntoResponse}; -use qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; +use qs::axum::{OptionalQsQuery, QsQuery, QsQueryConfig, QsQueryRejection};  use serde::de::Error;  fn from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error> @@ -132,3 +132,43 @@ fn test_custom_qs_config() {          assert!(s.common.remaining);      })  } + +#[test] +fn test_optional_query_none() { +    futures::executor::block_on(async { +        let req = axum::http::Request::builder() +            .uri("/test") +            .body(()) +            .unwrap(); +        let (mut req_parts, _) = req.into_parts(); + +        let OptionalQsQuery(s) = OptionalQsQuery::<Query>::from_request_parts(&mut req_parts, &()) +            .await +            .unwrap(); + +        assert!(s.is_none()); +    }) +} + +#[test] +fn test_optional_query_some() { +    futures::executor::block_on(async { +        let req = axum::http::Request::builder() +            .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true") +            .extension(QsQueryConfig::new(5, false)) +            .body(()) +            .unwrap(); + +        let (mut req_parts, _) = req.into_parts(); +        let OptionalQsQuery(s) = OptionalQsQuery::<Query>::from_request_parts(&mut req_parts, &()) +            .await +            .unwrap(); + +        let query = s.unwrap(); +        assert_eq!(query.foo, 1); +        assert_eq!(query.bars, vec![3]); +        assert_eq!(query.common.limit, 100); +        assert_eq!(query.common.offset, 50); +        assert!(query.common.remaining); +    }) +} | 
