//! Functionality for using `serde_qs` with `axum`. //! //! Enable with the `axum` feature. use axum_framework as axum; use std::sync::Arc; use crate::de::Config as QsConfig; use crate::error::Error as QsError; use axum::{ extract::{Extension, FromRequestParts}, http::StatusCode, response::{IntoResponse, Response}, BoxError, Error, }; #[derive(Clone, Copy, Default)] /// Extract typed information from from the request's query. /// /// ## Example /// /// ```rust /// # extern crate axum_framework as axum; /// use serde_qs::axum::QsQuery; /// use serde_qs::Config; /// use axum::{response::IntoResponse, routing::get, Router, body::Body}; /// /// #[derive(serde::Deserialize)] /// pub struct UsersFilter { /// id: Vec, /// } /// /// async fn filter_users( /// QsQuery(info): QsQuery /// ) -> impl IntoResponse { /// info.id /// .iter() /// .map(|i| i.to_string()) /// .collect::>() /// .join(", ") /// } /// /// fn main() { /// let app = Router::<()>::new() /// .route("/users", get(filter_users)); /// } pub struct QsQuery(pub T); impl std::ops::Deref for QsQuery { type Target = T; fn deref(&self) -> &Self::Target { &self.0 } } impl std::fmt::Display for QsQuery { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.0.fmt(f) } } impl std::fmt::Debug for QsQuery { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.0.fmt(f) } } impl FromRequestParts for QsQuery where T: serde::de::DeserializeOwned, S: Send + Sync, { type Rejection = QsQueryRejection; async fn from_request_parts( parts: &mut axum::http::request::Parts, state: &S, ) -> Result { let Extension(qs_config) = Extension::::from_request_parts(parts, state) .await .unwrap_or_else(|_| Extension(QsQueryConfig::default())); let error_handler = qs_config.error_handler.clone(); let config: QsConfig = qs_config.into(); let query = parts.uri.query().unwrap_or_default(); match config.deserialize_str::(query) { Ok(value) => Ok(QsQuery(value)), Err(err) => match error_handler { Some(handler) => Err((handler)(err)), None => Err(QsQueryRejection::new(err, StatusCode::BAD_REQUEST)), }, } } } /// Extractor that differentiates between the absence and presence of the query string /// using `Option`. Absence of query string encoded as `None`. Otherwise, it behaves /// identical to the `QsQuery`. /// /// ## Example /// /// ```rust /// # extern crate axum_framework as axum; /// use serde_qs::axum::OptionalQsQuery; /// use serde_qs::Config; /// use axum::{response::IntoResponse, routing::get, Router, body::Body}; /// /// #[derive(serde::Deserialize)] /// pub struct UsersFilter { /// id: Vec, /// } /// /// async fn filter_users( /// OptionalQsQuery(info): OptionalQsQuery /// ) -> impl IntoResponse { /// match info { /// Some(info) => todo!("Select users based on query string"), /// None => { todo!("No query string provided")} /// } /// } /// /// fn main() { /// let app = Router::<()>::new() /// .route("/users", get(filter_users)); /// } #[derive(Clone, Copy, Default)] pub struct OptionalQsQuery(pub Option); impl std::ops::Deref for OptionalQsQuery { type Target = Option; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for OptionalQsQuery { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } impl std::fmt::Debug for OptionalQsQuery { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.0.fmt(f) } } impl FromRequestParts for OptionalQsQuery where T: serde::de::DeserializeOwned, S: Send + Sync, { type Rejection = QsQueryRejection; async fn from_request_parts( parts: &mut axum::http::request::Parts, state: &S, ) -> Result { let Extension(qs_config) = Extension::::from_request_parts(parts, state) .await .unwrap_or_else(|_| Extension(QsQueryConfig::default())); if let Some(query) = parts.uri.query() { let error_handler = qs_config.error_handler.clone(); let config: QsConfig = qs_config.into(); config .deserialize_str::(query) .map(|query| OptionalQsQuery(Some(query))) .map_err(|err| match error_handler { Some(handler) => handler(err), None => QsQueryRejection::new(err, StatusCode::BAD_REQUEST), }) } else { Ok(OptionalQsQuery(None)) } } } #[derive(Debug)] /// Rejection type for extractors that deserialize query strings pub struct QsQueryRejection { error: axum::Error, status: StatusCode, } impl std::fmt::Display for QsQueryRejection { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Failed to deserialize query string. Error: {}", self.error, ) } } impl QsQueryRejection { /// Create new rejection pub fn new(error: E, status: StatusCode) -> Self where E: Into, { QsQueryRejection { error: Error::new(error), status, } } } impl IntoResponse for QsQueryRejection { fn into_response(self) -> Response { let mut res = self.to_string().into_response(); *res.status_mut() = self.status; res } } impl std::error::Error for QsQueryRejection { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) } } #[derive(Clone)] /// Query extractor configuration /// /// QsQueryConfig wraps [`Config`](crate::de::Config) and implement [`Clone`] /// for [`FromRequest`](https://docs.rs/axum/0.5/axum/extract/trait.FromRequest.html) /// /// ## Example /// /// ```rust /// # extern crate axum_framework as axum; /// use serde_qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; /// use serde_qs::Config; /// use axum::{ /// response::IntoResponse, /// routing::get, /// Router, /// body::Body, /// extract::Extension, /// http::StatusCode, /// }; /// use std::sync::Arc; /// /// #[derive(serde::Deserialize)] /// pub struct UsersFilter { /// id: Vec, /// } /// /// async fn filter_users( /// QsQuery(info): QsQuery /// ) -> impl IntoResponse { /// info.id /// .iter() /// .map(|i| i.to_string()) /// .collect::>() /// .join(", ") /// } /// /// fn main() { /// let app = Router::<()>::new() /// .route("/users", get(filter_users)) /// .layer(Extension(QsQueryConfig::new(5, false) /// .error_handler(|err| { /// QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY) /// }))); /// } pub struct QsQueryConfig { max_depth: usize, strict: bool, error_handler: Option QsQueryRejection + Send + Sync>>, } impl QsQueryConfig { /// Create new config wrapper pub fn new(max_depth: usize, strict: bool) -> Self { Self { max_depth, strict, error_handler: None, } } /// Set custom error handler pub fn error_handler(mut self, f: F) -> Self where F: Fn(QsError) -> QsQueryRejection + Send + Sync + 'static, { self.error_handler = Some(Arc::new(f)); self } } impl From for QsConfig { fn from(config: QsQueryConfig) -> Self { Self::new(config.max_depth, config.strict) } } impl Default for QsQueryConfig { fn default() -> Self { Self { max_depth: 5, strict: true, error_handler: None, } } }