summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/axum.rs84
-rw-r--r--tests/test_axum.rs42
2 files changed, 125 insertions, 1 deletions
diff --git a/src/axum.rs b/src/axum.rs
index 0493b72..e9a3c17 100644
--- a/src/axum.rs
+++ b/src/axum.rs
@@ -96,6 +96,90 @@ where
}
}
+/// Extractor that differentiates between the absence and presence of the query string
+/// using `Option<T>`. Absence of query string encoded as `None`. Otherwise, it behaves
+/// identical to the `QsQuery`.
+///
+/// ## Example
+///
+/// ```rust
+/// # extern crate axum_framework as axum;
+/// use serde_qs::axum::OptionalQsQuery;
+/// use serde_qs::Config;
+/// use axum::{response::IntoResponse, routing::get, Router, body::Body};
+///
+/// #[derive(serde::Deserialize)]
+/// pub struct UsersFilter {
+/// id: Vec<u64>,
+/// }
+///
+/// async fn filter_users(
+/// OptionalQsQuery(info): OptionalQsQuery<UsersFilter>
+/// ) -> impl IntoResponse {
+/// match info {
+/// Some(info) => todo!("Select users based on query string"),
+/// None => { todo!("No query string provided")}
+/// }
+/// }
+///
+/// fn main() {
+/// let app = Router::<()>::new()
+/// .route("/users", get(filter_users));
+/// }
+#[derive(Clone, Copy, Default)]
+pub struct OptionalQsQuery<T>(pub Option<T>);
+
+impl<T> std::ops::Deref for OptionalQsQuery<T> {
+ type Target = Option<T>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<T> std::ops::DerefMut for OptionalQsQuery<T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+impl<T: std::fmt::Debug> std::fmt::Debug for OptionalQsQuery<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[axum::async_trait]
+impl<T, S> FromRequestParts<S> for OptionalQsQuery<T>
+where
+ T: serde::de::DeserializeOwned,
+ S: Send + Sync,
+{
+ type Rejection = QsQueryRejection;
+
+ async fn from_request_parts(
+ parts: &mut axum::http::request::Parts,
+ state: &S,
+ ) -> Result<Self, Self::Rejection> {
+ let Extension(qs_config) = Extension::<QsQueryConfig>::from_request_parts(parts, state)
+ .await
+ .unwrap_or_else(|_| Extension(QsQueryConfig::default()));
+ if let Some(query) = parts.uri.query() {
+ let error_handler = qs_config.error_handler.clone();
+ let config: QsConfig = qs_config.into();
+ config
+ .deserialize_str::<T>(query)
+ .map(|query| OptionalQsQuery(Some(query)))
+ .map_err(|err| match error_handler {
+ Some(handler) => handler(err),
+ None => QsQueryRejection::new(err, StatusCode::BAD_REQUEST),
+ })
+ } else {
+ Ok(OptionalQsQuery(None))
+ }
+ }
+}
+
#[derive(Debug)]
/// Rejection type for extractors that deserialize query strings
pub struct QsQueryRejection {
diff --git a/tests/test_axum.rs b/tests/test_axum.rs
index 16449b9..dd86a9e 100644
--- a/tests/test_axum.rs
+++ b/tests/test_axum.rs
@@ -8,7 +8,7 @@ extern crate axum_framework as axum;
extern crate serde_qs as qs;
use axum::{extract::FromRequestParts, http::StatusCode, response::IntoResponse};
-use qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection};
+use qs::axum::{OptionalQsQuery, QsQuery, QsQueryConfig, QsQueryRejection};
use serde::de::Error;
fn from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error>
@@ -132,3 +132,43 @@ fn test_custom_qs_config() {
assert!(s.common.remaining);
})
}
+
+#[test]
+fn test_optional_query_none() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test")
+ .body(())
+ .unwrap();
+ let (mut req_parts, _) = req.into_parts();
+
+ let OptionalQsQuery(s) = OptionalQsQuery::<Query>::from_request_parts(&mut req_parts, &())
+ .await
+ .unwrap();
+
+ assert!(s.is_none());
+ })
+}
+
+#[test]
+fn test_optional_query_some() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true")
+ .extension(QsQueryConfig::new(5, false))
+ .body(())
+ .unwrap();
+
+ let (mut req_parts, _) = req.into_parts();
+ let OptionalQsQuery(s) = OptionalQsQuery::<Query>::from_request_parts(&mut req_parts, &())
+ .await
+ .unwrap();
+
+ let query = s.unwrap();
+ assert_eq!(query.foo, 1);
+ assert_eq!(query.bars, vec![3]);
+ assert_eq!(query.common.limit, 100);
+ assert_eq!(query.common.offset, 50);
+ assert!(query.common.remaining);
+ })
+}