diff options
author | Denis Kayshev <topenkoff@gmail.com> | 2022-04-29 18:56:23 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-29 11:56:23 -0400 |
commit | aaad564074864ad463e369af7f98655b01565143 (patch) | |
tree | 6f56d7d8c372308dfd24bca76d24ec217dc39d86 | |
parent | 4bd3699faba3f00dd8f59a358605abb6e485deb9 (diff) |
Add axum support (#59)
-rw-r--r-- | .github/workflows/ci.yml | 4 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | src/axum.rs | 225 | ||||
-rw-r--r-- | src/lib.rs | 3 | ||||
-rw-r--r-- | tests/test_axum.rs | 142 |
5 files changed, 376 insertions, 0 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2f399ca..960e342 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,10 @@ jobs: run: | cargo test --all-targets --features warp cargo test --doc --features warp + - name: Run test axum + run: | + cargo test --all-targets --features axum + cargo test --doc --features axum - name: Run test no feature run: | cargo test --all-targets @@ -22,6 +22,7 @@ serde = "1.0" thiserror = "1.0" tracing = { version = "0.1", optional = true } warp-framework = { package = "warp", version = "0.3", default-features = false, optional = true } +axum-framework = { package = "axum", version = "0.5", default-features = false, optional = true } [dev-dependencies] csv = "1.1" @@ -38,6 +39,7 @@ actix2 = ["actix-web2", "futures"] # deprecated feature -- used to return a warning actix = [] warp = ["futures", "tracing", "warp-framework"] +axum = ["axum-framework", "futures"] [package.metadata.docs.rs] features = ["actix4", "warp"] diff --git a/src/axum.rs b/src/axum.rs new file mode 100644 index 0000000..1605146 --- /dev/null +++ b/src/axum.rs @@ -0,0 +1,225 @@ +//! Functionality for using `serde_qs` with `axum`. +//! +//! Enable with the `axum` feature. + +use axum_framework as axum; + +use std::sync::Arc; + +use crate::de::Config as QsConfig; +use crate::error::Error as QsError; + +use axum::{ + extract::{Extension, FromRequest, RequestParts}, + http::StatusCode, + response::{IntoResponse, Response}, + BoxError, Error, +}; + +#[derive(Clone, Copy, Default)] +/// Extract typed information from from the request's query. +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::QsQuery; +/// use serde_qs::Config; +/// use axum::{response::IntoResponse, routing::get, Router, body::Body}; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +/// id: Vec<u64>, +/// } +/// +/// async fn filter_users( +/// QsQuery(info): QsQuery<UsersFilter> +/// ) -> impl IntoResponse { +/// info.id +/// .iter() +/// .map(|i| i.to_string()) +/// .collect::<Vec<String>>() +/// .join(", ") +/// } +/// +/// fn main() { +/// let app = Router::<Body>::new() +/// .route("/users", get(filter_users)); +/// } +pub struct QsQuery<T>(pub T); + +impl<T> std::ops::Deref for QsQuery<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<T: std::fmt::Display> std::fmt::Display for QsQuery<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl<T: std::fmt::Debug> std::fmt::Debug for QsQuery<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[axum::async_trait] +impl<T, B> FromRequest<B> for QsQuery<T> +where + T: serde::de::DeserializeOwned, + B: std::marker::Send, +{ + type Rejection = QsQueryRejection; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + let Extension(qs_config) = Extension::<QsQueryConfig>::from_request(req) + .await + .unwrap_or_else(|_| Extension(QsQueryConfig::default())); + let error_handler = qs_config.error_handler.clone(); + let config: QsConfig = qs_config.into(); + let query = req.uri().query().unwrap_or_default(); + match config.deserialize_str::<T>(query) { + Ok(value) => Ok(QsQuery(value)), + Err(err) => match error_handler { + Some(handler) => Err((handler)(err)), + None => Err(QsQueryRejection::new(err, StatusCode::BAD_REQUEST)), + }, + } + } +} + +#[derive(Debug)] +/// Rejection type for extractors that deserialize query strings +pub struct QsQueryRejection { + error: axum::Error, + status: StatusCode, +} + +impl std::fmt::Display for QsQueryRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Failed to deserialize query string. Error: {}", + self.error, + ) + } +} + +impl QsQueryRejection { + /// Create new rejection + pub fn new<E>(error: E, status: StatusCode) -> Self + where + E: Into<BoxError>, + { + QsQueryRejection { + error: Error::new(error), + status, + } + } +} + +impl IntoResponse for QsQueryRejection { + fn into_response(self) -> Response { + let mut res = self.to_string().into_response(); + *res.status_mut() = self.status; + res + } +} + +impl std::error::Error for QsQueryRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.error) + } +} + +#[derive(Clone)] +/// Query extractor configuration +/// +/// QsQueryConfig wraps [`Config`](crate::de::Config) and implement [`Clone`] +/// for [`FromRequest`](https://docs.rs/axum/0.5/axum/extract/trait.FromRequest.html) +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; +/// use serde_qs::Config; +/// use axum::{ +/// response::IntoResponse, +/// routing::get, +/// Router, +/// body::Body, +/// extract::Extension, +/// http::StatusCode, +/// }; +/// use std::sync::Arc; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +/// id: Vec<u64>, +/// } +/// +/// async fn filter_users( +/// QsQuery(info): QsQuery<UsersFilter> +/// ) -> impl IntoResponse { +/// info.id +/// .iter() +/// .map(|i| i.to_string()) +/// .collect::<Vec<String>>() +/// .join(", ") +/// } +/// +/// fn main() { +/// let app = Router::<Body>::new() +/// .route("/users", get(filter_users)) +/// .layer(Extension(Arc::new(QsQueryConfig::new(5, false) +/// .error_handler(|err| { +/// QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY) +/// })))); +/// } +pub struct QsQueryConfig { + max_depth: usize, + strict: bool, + error_handler: Option<Arc<dyn Fn(QsError) -> QsQueryRejection + Send + Sync>>, +} + +impl QsQueryConfig { + /// Create new config wrapper + pub fn new(max_depth: usize, strict: bool) -> Self { + Self { + max_depth, + strict, + error_handler: None, + } + } + + /// Set custom error handler + pub fn error_handler<F>(mut self, f: F) -> Self + where + F: Fn(QsError) -> QsQueryRejection + Send + Sync + 'static, + { + self.error_handler = Some(Arc::new(f)); + self + } +} + +impl From<QsQueryConfig> for QsConfig { + fn from(config: QsQueryConfig) -> Self { + Self::new(config.max_depth, config.strict) + } +} + +impl Default for QsQueryConfig { + fn default() -> Self { + Self { + max_depth: 5, + strict: true, + error_handler: None, + } + } +} @@ -200,3 +200,6 @@ pub use de::{from_bytes, from_str}; pub use error::Error; #[doc(inline)] pub use ser::{to_string, to_writer, QsSerializer}; + +#[cfg(feature = "axum")] +pub mod axum; diff --git a/tests/test_axum.rs b/tests/test_axum.rs new file mode 100644 index 0000000..0a07b19 --- /dev/null +++ b/tests/test_axum.rs @@ -0,0 +1,142 @@ +#![cfg(feature = "axum")] + +extern crate serde; + +#[macro_use] +extern crate serde_derive; +extern crate axum_framework as axum; +extern crate serde_qs as qs; + +use axum::{ + extract::{FromRequest, RequestParts}, + http::StatusCode, + response::IntoResponse, +}; +use qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; +use serde::de::Error; + +fn from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error> +where + D: serde::Deserializer<'de>, + S: std::str::FromStr, +{ + let s = <&str as serde::Deserialize>::deserialize(deserializer)?; + S::from_str(s).map_err(|_| D::Error::custom("could not parse string")) +} + +#[derive(Deserialize, Serialize, Debug, PartialEq)] +struct Query { + foo: u64, + bars: Vec<u64>, + #[serde(flatten)] + common: CommonParams, +} + +#[derive(Deserialize, Serialize, Debug, PartialEq)] +struct CommonParams { + #[serde(deserialize_with = "from_str")] + limit: u64, + #[serde(deserialize_with = "from_str")] + offset: u64, + #[serde(deserialize_with = "from_str")] + remaining: bool, +} + +#[test] +fn test_default_error_handler() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test") + .body(()) + .unwrap(); + let mut req_parts = RequestParts::new(req); + + let e = QsQuery::<Query>::from_request(&mut req_parts) + .await + .unwrap_err(); + + assert_eq!(e.into_response().status(), StatusCode::BAD_REQUEST); + }) +} + +#[test] +fn test_custom_error_handler() { + futures::executor::block_on(async { + let req = + axum::http::Request::builder() + .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true") + .extension(QsQueryConfig::default().error_handler(|err| { + QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY) + })) + .body(()) + .unwrap(); + let mut req_parts = RequestParts::new(req); + + let query = QsQuery::<Query>::from_request(&mut req_parts).await; + + assert!(query.is_err()); + assert_eq!( + query.unwrap_err().into_response().status(), + StatusCode::UNPROCESSABLE_ENTITY + ); + }) +} + +#[test] +fn test_composite_querystring_extractor() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test?foo=1&bars[]=0&bars[]=1&limit=100&offset=50&remaining=true") + .body(()) + .unwrap(); + let mut req_parts = RequestParts::new(req); + + let s = QsQuery::<Query>::from_request(&mut req_parts) + .await + .unwrap(); + assert_eq!(s.foo, 1); + assert_eq!(s.bars, vec![0, 1]); + assert_eq!(s.common.limit, 100); + assert_eq!(s.common.offset, 50); + assert!(s.common.remaining); + }) +} + +#[test] +fn test_default_qs_config() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true") + .body(()) + .unwrap(); + let mut req_parts = RequestParts::new(req); + + let e = QsQuery::<Query>::from_request(&mut req_parts) + .await + .unwrap_err(); + + assert_eq!(e.into_response().status(), StatusCode::BAD_REQUEST); + }) +} + +#[test] +fn test_custom_qs_config() { + futures::executor::block_on(async { + let req = axum::http::Request::builder() + .uri("/test?foo=1&bars%5B%5D=3&limit=100&offset=50&remaining=true") + .extension(QsQueryConfig::new(5, false)) + .body(()) + .unwrap(); + + let mut req_parts = RequestParts::new(req); + + let s = QsQuery::<Query>::from_request(&mut req_parts) + .await + .unwrap(); + assert_eq!(s.foo, 1); + assert_eq!(s.bars, vec![3]); + assert_eq!(s.common.limit, 100); + assert_eq!(s.common.offset, 50); + assert!(s.common.remaining); + }) +} |