diff options
author | Denis Kayshev <topenkoff@gmail.com> | 2022-04-29 18:56:23 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-29 11:56:23 -0400 |
commit | aaad564074864ad463e369af7f98655b01565143 (patch) | |
tree | 6f56d7d8c372308dfd24bca76d24ec217dc39d86 /src | |
parent | 4bd3699faba3f00dd8f59a358605abb6e485deb9 (diff) |
Add axum support (#59)
Diffstat (limited to 'src')
-rw-r--r-- | src/axum.rs | 225 | ||||
-rw-r--r-- | src/lib.rs | 3 |
2 files changed, 228 insertions, 0 deletions
diff --git a/src/axum.rs b/src/axum.rs new file mode 100644 index 0000000..1605146 --- /dev/null +++ b/src/axum.rs @@ -0,0 +1,225 @@ +//! Functionality for using `serde_qs` with `axum`. +//! +//! Enable with the `axum` feature. + +use axum_framework as axum; + +use std::sync::Arc; + +use crate::de::Config as QsConfig; +use crate::error::Error as QsError; + +use axum::{ + extract::{Extension, FromRequest, RequestParts}, + http::StatusCode, + response::{IntoResponse, Response}, + BoxError, Error, +}; + +#[derive(Clone, Copy, Default)] +/// Extract typed information from from the request's query. +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::QsQuery; +/// use serde_qs::Config; +/// use axum::{response::IntoResponse, routing::get, Router, body::Body}; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +/// id: Vec<u64>, +/// } +/// +/// async fn filter_users( +/// QsQuery(info): QsQuery<UsersFilter> +/// ) -> impl IntoResponse { +/// info.id +/// .iter() +/// .map(|i| i.to_string()) +/// .collect::<Vec<String>>() +/// .join(", ") +/// } +/// +/// fn main() { +/// let app = Router::<Body>::new() +/// .route("/users", get(filter_users)); +/// } +pub struct QsQuery<T>(pub T); + +impl<T> std::ops::Deref for QsQuery<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<T: std::fmt::Display> std::fmt::Display for QsQuery<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl<T: std::fmt::Debug> std::fmt::Debug for QsQuery<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[axum::async_trait] +impl<T, B> FromRequest<B> for QsQuery<T> +where + T: serde::de::DeserializeOwned, + B: std::marker::Send, +{ + type Rejection = QsQueryRejection; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + let Extension(qs_config) = Extension::<QsQueryConfig>::from_request(req) + .await + .unwrap_or_else(|_| Extension(QsQueryConfig::default())); + let error_handler = qs_config.error_handler.clone(); + let config: QsConfig = qs_config.into(); + let query = req.uri().query().unwrap_or_default(); + match config.deserialize_str::<T>(query) { + Ok(value) => Ok(QsQuery(value)), + Err(err) => match error_handler { + Some(handler) => Err((handler)(err)), + None => Err(QsQueryRejection::new(err, StatusCode::BAD_REQUEST)), + }, + } + } +} + +#[derive(Debug)] +/// Rejection type for extractors that deserialize query strings +pub struct QsQueryRejection { + error: axum::Error, + status: StatusCode, +} + +impl std::fmt::Display for QsQueryRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Failed to deserialize query string. Error: {}", + self.error, + ) + } +} + +impl QsQueryRejection { + /// Create new rejection + pub fn new<E>(error: E, status: StatusCode) -> Self + where + E: Into<BoxError>, + { + QsQueryRejection { + error: Error::new(error), + status, + } + } +} + +impl IntoResponse for QsQueryRejection { + fn into_response(self) -> Response { + let mut res = self.to_string().into_response(); + *res.status_mut() = self.status; + res + } +} + +impl std::error::Error for QsQueryRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.error) + } +} + +#[derive(Clone)] +/// Query extractor configuration +/// +/// QsQueryConfig wraps [`Config`](crate::de::Config) and implement [`Clone`] +/// for [`FromRequest`](https://docs.rs/axum/0.5/axum/extract/trait.FromRequest.html) +/// +/// ## Example +/// +/// ```rust +/// # extern crate axum_framework as axum; +/// use serde_qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; +/// use serde_qs::Config; +/// use axum::{ +/// response::IntoResponse, +/// routing::get, +/// Router, +/// body::Body, +/// extract::Extension, +/// http::StatusCode, +/// }; +/// use std::sync::Arc; +/// +/// #[derive(serde::Deserialize)] +/// pub struct UsersFilter { +/// id: Vec<u64>, +/// } +/// +/// async fn filter_users( +/// QsQuery(info): QsQuery<UsersFilter> +/// ) -> impl IntoResponse { +/// info.id +/// .iter() +/// .map(|i| i.to_string()) +/// .collect::<Vec<String>>() +/// .join(", ") +/// } +/// +/// fn main() { +/// let app = Router::<Body>::new() +/// .route("/users", get(filter_users)) +/// .layer(Extension(Arc::new(QsQueryConfig::new(5, false) +/// .error_handler(|err| { +/// QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY) +/// })))); +/// } +pub struct QsQueryConfig { + max_depth: usize, + strict: bool, + error_handler: Option<Arc<dyn Fn(QsError) -> QsQueryRejection + Send + Sync>>, +} + +impl QsQueryConfig { + /// Create new config wrapper + pub fn new(max_depth: usize, strict: bool) -> Self { + Self { + max_depth, + strict, + error_handler: None, + } + } + + /// Set custom error handler + pub fn error_handler<F>(mut self, f: F) -> Self + where + F: Fn(QsError) -> QsQueryRejection + Send + Sync + 'static, + { + self.error_handler = Some(Arc::new(f)); + self + } +} + +impl From<QsQueryConfig> for QsConfig { + fn from(config: QsQueryConfig) -> Self { + Self::new(config.max_depth, config.strict) + } +} + +impl Default for QsQueryConfig { + fn default() -> Self { + Self { + max_depth: 5, + strict: true, + error_handler: None, + } + } +} @@ -200,3 +200,6 @@ pub use de::{from_bytes, from_str}; pub use error::Error; #[doc(inline)] pub use ser::{to_string, to_writer, QsSerializer}; + +#[cfg(feature = "axum")] +pub mod axum; |