summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDenis Kayshev <topenkoff@gmail.com>2022-04-29 18:56:23 +0300
committerGitHub <noreply@github.com>2022-04-29 11:56:23 -0400
commitaaad564074864ad463e369af7f98655b01565143 (patch)
tree6f56d7d8c372308dfd24bca76d24ec217dc39d86
parent4bd3699faba3f00dd8f59a358605abb6e485deb9 (diff)
Add axum support (#59)
-rw-r--r--.github/workflows/ci.yml4
-rw-r--r--Cargo.toml2
-rw-r--r--src/axum.rs225
-rw-r--r--src/lib.rs3
-rw-r--r--tests/test_axum.rs142
5 files changed, 376 insertions, 0 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2f399ca..960e342 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -69,6 +69,10 @@ jobs:
run: |
cargo test --all-targets --features warp
cargo test --doc --features warp
+ - name: Run test axum
+ run: |
+ cargo test --all-targets --features axum
+ cargo test --doc --features axum
- name: Run test no feature
run: |
cargo test --all-targets
diff --git a/Cargo.toml b/Cargo.toml
index ad417fa..5f4336d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,7 @@ serde = "1.0"
thiserror = "1.0"
tracing = { version = "0.1", optional = true }
warp-framework = { package = "warp", version = "0.3", default-features = false, optional = true }
+axum-framework = { package = "axum", version = "0.5", default-features = false, optional = true }
[dev-dependencies]
csv = "1.1"
@@ -38,6 +39,7 @@ actix2 = ["actix-web2", "futures"]
# deprecated feature -- used to return a warning
actix = []
warp = ["futures", "tracing", "warp-framework"]
+axum = ["axum-framework", "futures"]
[package.metadata.docs.rs]
features = ["actix4", "warp"]
diff --git a/src/axum.rs b/src/axum.rs
new file mode 100644
index 0000000..1605146
--- /dev/null
+++ b/src/axum.rs
@@ -0,0 +1,225 @@
+//! Functionality for using `serde_qs` with `axum`.
+//!
+//! Enable with the `axum` feature.
+
+use axum_framework as axum;
+
+use std::sync::Arc;
+
+use crate::de::Config as QsConfig;
+use crate::error::Error as QsError;
+
+use axum::{
+ extract::{Extension, FromRequest, RequestParts},
+ http::StatusCode,
+ response::{IntoResponse, Response},
+ BoxError, Error,
+};
+
+#[derive(Clone, Copy, Default)]
+/// Extract typed information from from the request's query.
+///
+/// ## Example
+///
+/// ```rust
+/// # extern crate axum_framework as axum;
+/// use serde_qs::axum::QsQuery;
+/// use serde_qs::Config;
+/// use axum::{response::IntoResponse, routing::get, Router, body::Body};
+///
+/// #[derive(serde::Deserialize)]
+/// pub struct UsersFilter {
+/// id: Vec<u64>,
+/// }
+///
+/// async fn filter_users(
+/// QsQuery(info): QsQuery<UsersFilter>
+/// ) -> impl IntoResponse {
+/// info.id
+/// .iter()
+/// .map(|i| i.to_string())
+/// .collect::<Vec<String>>()
+/// .join(", ")
+/// }
+///
+/// fn main() {
+/// let app = Router::<Body>::new()
+/// .route("/users", get(filter_users));
+/// }
+pub struct QsQuery<T>(pub T);
+
+impl<T> std::ops::Deref for QsQuery<T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<T: std::fmt::Display> std::fmt::Display for QsQuery<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl<T: std::fmt::Debug> std::fmt::Debug for QsQuery<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[axum::async_trait]
+impl<T, B> FromRequest<B> for QsQuery<T>
+where
+ T: serde::de::DeserializeOwned,
+ B: std::marker::Send,
+{
+ type Rejection = QsQueryRejection;
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let Extension(qs_config) = Extension::<QsQueryConfig>::from_request(req)
+ .await
+ .unwrap_or_else(|_| Extension(QsQueryConfig::default()));
+ let error_handler = qs_config.error_handler.clone();
+ let config: QsConfig = qs_config.into();
+ let query = req.uri().query().unwrap_or_default();
+ match config.deserialize_str::<T>(query) {
+ Ok(value) => Ok(QsQuery(value)),
+ Err(err) => match error_handler {
+ Some(handler) => Err((handler)(err)),
+ None => Err(QsQueryRejection::new(err, StatusCode::BAD_REQUEST)),
+ },
+ }
+ }
+}
+
+#[derive(Debug)]
+/// Rejection type for extractors that deserialize query strings
+pub struct QsQueryRejection {
+ error: axum::Error,
+ status: StatusCode,
+}
+
+impl std::fmt::Display for QsQueryRejection {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "Failed to deserialize query string. Error: {}",
+ self.error,
+ )
+ }
+}
+
+impl QsQueryRejection {
+ /// Create new rejection
+ pub fn new<E>(error: E, status: StatusCode) -> Self
+ where
+ E: Into<BoxError>,
+ {
+ QsQueryRejection {
+ error: Error::new(error),
+ status,
+ }
+ }
+}
+
+impl IntoResponse for QsQueryRejection {
+ fn into_response(self) -> Response {
+ let mut res = self.to_string().into_response();
+ *res.status_mut() = self.status;
+ res
+ }
+}
+
+impl std::error::Error for QsQueryRejection {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ Some(&self.error)
+ }
+}
+
+#[derive(Clone)]
+/// Query extractor configuration
+///
+/// QsQueryConfig wraps [`Config`](crate::de::Config) and implement [`Clone`]
+/// for [`FromRequest`](https://docs.rs/axum/0.5/axum/extract/trait.FromRequest.html)
+///
+/// ## Example
+///
+/// ```rust
+/// # extern crate axum_framework as axum;
+/// use serde_qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection};
+/// use serde_qs::Config;
+/// use axum::{
+/// response::IntoResponse,
+/// routing::get,
+/// Router,
+/// body::Body,
+/// extract::Extension,
+/// http::StatusCode,
+/// };
+/// use std::sync::Arc;
+///
+/// #[derive(serde::Deserialize)]
+/// pub struct UsersFilter {
+/// id: Vec<u64>,
+/// }
+///
+/// async fn filter_users(
+/// QsQuery(info): QsQuery<UsersFilter>
+/// ) -> impl IntoResponse {
+/// info.id
+/// .iter()
+/// .map(|i| i.to_string())
+/// .collect::<Vec<String>>()
+/// .join(", ")
+/// }
+///
+/// fn main() {
+/// let app = Router::<Body>::new()
+/// .route("/users", get(filter_users))
+/// .layer(Extension(Arc::new(QsQueryConfig::new(5, false)
+/// .error_handler(|err| {
+/// QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY)
+/// }))));
+/// }
+pub struct QsQueryConfig {
+ max_depth: usize,
+ strict: bool,
+ error_handler: Option<Arc<dyn Fn(QsError) -> QsQueryRejection + Send + Sync>>,
+}
+
+impl QsQueryConfig {
+ /// Create new config wrapper
+ pub fn new(max_depth: usize, strict: bool) -> Self {
+ Self {
+ max_depth,
+ strict,
+ error_handler: None,
+ }
+ }
+
+ /// Set custom error handler
+ pub fn error_handler<F>(mut self, f: F) -> Self
+ where
+ F: Fn(QsError) -> QsQueryRejection + Send + Sync + 'static,
+ {
+ self.error_handler = Some(Arc::new(f));
+ self
+ }
+}
+
+impl From<QsQueryConfig> for QsConfig {
+ fn from(config: QsQueryConfig) -> Self {
+ Self::new(config.max_depth, config.strict)
+ }
+}
+
+impl Default for QsQueryConfig {
+ fn default() -> Self {
+ Self {
+ max_depth: 5,
+ strict: true,
+ error_handler: None,
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 803b401..96ce141 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -200,3 +200,6 @@ pub use de::{from_bytes, from_str};
pub use error::Error;
#[doc(inline)]
pub use ser::{to_string, to_writer, QsSerializer};
+
+#[cfg(feature = "axum")]
+pub mod axum;
diff --git a/tests/test_axum.rs b/tests/test_axum.rs
new file mode 100644
index 0000000..0a07b19
--- /dev/null
+++ b/tests/test_axum.rs
@@ -0,0 +1,142 @@
+#![cfg(feature = "axum")]
+
+extern crate serde;
+
+#[macro_use]
+extern crate serde_derive;
+extern crate axum_framework as axum;
+extern crate serde_qs as qs;
+
+use axum::{
+ extract::{FromRequest, RequestParts},
+ http::StatusCode,
+ response::IntoResponse,
+};
+use qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection};
+use serde::de::Error;
+
+fn from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error>
+where
+ D: serde::Deserializer<'de>,
+ S: std::str::FromStr,
+{
+ let s = <&str as serde::Deserialize>::deserialize(deserializer)?;
+ S::from_str(s).map_err(|_| D::Error::custom("could not parse string"))
+}
+
+#[derive(Deserialize, Serialize, Debug, PartialEq)]
+struct Query {
+ foo: u64,
+ bars: Vec<u64>,
+ #[serde(flatten)]
+ common: CommonParams,
+}
+
+#[derive(Deserialize, Serialize, Debug, PartialEq)]
+struct CommonParams {
+ #[serde(deserialize_with = "from_str")]
+ limit: u64,
+ #[serde(deserialize_with = "from_str")]
+ offset: u64,
+ #[serde(deserialize_with = "from_str")]
+ remaining: bool,
+}
+
+#[test]
+fn test_default_error_handler() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test")
+ .body(())
+ .unwrap();
+ let mut req_parts = RequestParts::new(req);
+
+ let e = QsQuery::<Query>::from_request(&mut req_parts)
+ .await
+ .unwrap_err();
+
+ assert_eq!(e.into_response().status(), StatusCode::BAD_REQUEST);
+ })
+}
+
+#[test]
+fn test_custom_error_handler() {
+ futures::executor::block_on(async {
+ let req =
+ axum::http::Request::builder()
+ .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true")
+ .extension(QsQueryConfig::default().error_handler(|err| {
+ QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY)
+ }))
+ .body(())
+ .unwrap();
+ let mut req_parts = RequestParts::new(req);
+
+ let query = QsQuery::<Query>::from_request(&mut req_parts).await;
+
+ assert!(query.is_err());
+ assert_eq!(
+ query.unwrap_err().into_response().status(),
+ StatusCode::UNPROCESSABLE_ENTITY
+ );
+ })
+}
+
+#[test]
+fn test_composite_querystring_extractor() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test?foo=1&bars[]=0&bars[]=1&limit=100&offset=50&remaining=true")
+ .body(())
+ .unwrap();
+ let mut req_parts = RequestParts::new(req);
+
+ let s = QsQuery::<Query>::from_request(&mut req_parts)
+ .await
+ .unwrap();
+ assert_eq!(s.foo, 1);
+ assert_eq!(s.bars, vec![0, 1]);
+ assert_eq!(s.common.limit, 100);
+ assert_eq!(s.common.offset, 50);
+ assert!(s.common.remaining);
+ })
+}
+
+#[test]
+fn test_default_qs_config() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true")
+ .body(())
+ .unwrap();
+ let mut req_parts = RequestParts::new(req);
+
+ let e = QsQuery::<Query>::from_request(&mut req_parts)
+ .await
+ .unwrap_err();
+
+ assert_eq!(e.into_response().status(), StatusCode::BAD_REQUEST);
+ })
+}
+
+#[test]
+fn test_custom_qs_config() {
+ futures::executor::block_on(async {
+ let req = axum::http::Request::builder()
+ .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true")
+ .extension(QsQueryConfig::new(5, false))
+ .body(())
+ .unwrap();
+
+ let mut req_parts = RequestParts::new(req);
+
+ let s = QsQuery::<Query>::from_request(&mut req_parts)
+ .await
+ .unwrap();
+ assert_eq!(s.foo, 1);
+ assert_eq!(s.bars, vec![3]);
+ assert_eq!(s.common.limit, 100);
+ assert_eq!(s.common.offset, 50);
+ assert!(s.common.remaining);
+ })
+}