From 096d851f6d878c2c727f530ccc5749b4d8be2ab0 Mon Sep 17 00:00:00 2001
From: Rudi Floren <rudi.floren@gmail.com>
Date: Thu, 2 Mar 2023 06:01:26 +0100
Subject: Improve support unit types (#76)

* fix unit type serialization issue

`()` and `A` returned a no key error previously. This is very
unergonimic if you just have a trait bound for Serialize and want to
generate an empty querystring `?`

* add support for deserializing unit structs

* Comment updates for serializer methods
---
 src/de/mod.rs             | 17 +++++++++++++++--
 src/ser.rs                | 43 +++++++++++++++++++++----------------------
 tests/test_deserialize.rs | 23 +++++++++++++++++++++++
 tests/test_serialize.rs   | 38 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 97 insertions(+), 24 deletions(-)

diff --git a/src/de/mod.rs b/src/de/mod.rs
index fd34131..ef9cd76 100644
--- a/src/de/mod.rs
+++ b/src/de/mod.rs
@@ -221,10 +221,14 @@ impl<'a> QsDeserializer<'a> {
 impl<'de> de::Deserializer<'de> for QsDeserializer<'de> {
     type Error = Error;
 
-    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
+    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value>
     where
         V: de::Visitor<'de>,
     {
+        if self.iter.next().is_none() {
+            return visitor.visit_unit();
+        }
+
         Err(Error::top_level("primitive"))
     }
 
@@ -572,6 +576,16 @@ impl<'de> de::Deserializer<'de> for LevelDeserializer<'de> {
         }
     }
 
+    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
+    where
+        V: de::Visitor<'de>,
+    {
+        match self.0 {
+            Level::Flat(ref x) if x == "" => visitor.visit_unit(),
+            _ => Err(de::Error::custom("expected unit".to_owned())),
+        }
+    }
+
     fn deserialize_enum<V>(
         self,
         name: &'static str,
@@ -646,7 +660,6 @@ impl<'de> de::Deserializer<'de> for LevelDeserializer<'de> {
         string
         bytes
         byte_buf
-        unit
         unit_struct
         // newtype_struct
         tuple_struct
diff --git a/src/ser.rs b/src/ser.rs
index b1e46d7..607fa0d 100644
--- a/src/ser.rs
+++ b/src/ser.rs
@@ -156,12 +156,10 @@ impl<'a, W: Write> ser::Serializer for &'a mut Serializer<W> {
         self.as_qs_serializer().serialize_unit()
     }
 
-    /// Returns an error.
     fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok> {
         self.as_qs_serializer().serialize_unit_struct(name)
     }
 
-    /// Returns an error.
     fn serialize_unit_variant(
         self,
         name: &'static str,
@@ -172,7 +170,6 @@ impl<'a, W: Write> ser::Serializer for &'a mut Serializer<W> {
             .serialize_unit_variant(name, variant_index, variant)
     }
 
-    /// Returns an error.
     fn serialize_newtype_struct<T: ?Sized + ser::Serialize>(
         self,
         name: &'static str,
@@ -182,7 +179,6 @@ impl<'a, W: Write> ser::Serializer for &'a mut Serializer<W> {
             .serialize_newtype_struct(name, value)
     }
 
-    /// Returns an error.
     fn serialize_newtype_variant<T: ?Sized + ser::Serialize>(
         self,
         name: &'static str,
@@ -202,7 +198,6 @@ impl<'a, W: Write> ser::Serializer for &'a mut Serializer<W> {
         self.as_qs_serializer().serialize_some(value)
     }
 
-    /// Returns an error.
     fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
         self.as_qs_serializer().serialize_seq(len)
     }
@@ -211,7 +206,6 @@ impl<'a, W: Write> ser::Serializer for &'a mut Serializer<W> {
         self.as_qs_serializer().serialize_tuple(len)
     }
 
-    /// Returns an error.
     fn serialize_tuple_struct(
         self,
         name: &'static str,
@@ -281,14 +275,11 @@ impl<'a, W: 'a + Write> QsSerializer<'a, W> {
 
     fn write_value(&mut self, value: &[u8]) -> Result<()> {
         if let Some(ref key) = self.key {
+            let amp = !self.first.swap(false, Ordering::Relaxed);
             write!(
                 self.writer,
                 "{}{}={}",
-                if self.first.swap(false, Ordering::Relaxed) {
-                    ""
-                } else {
-                    "&"
-                },
+                amp.then_some("&").unwrap_or_default(),
                 key,
                 percent_encode(value, QS_ENCODE_SET)
                     .map(replace_space)
@@ -300,6 +291,22 @@ impl<'a, W: 'a + Write> QsSerializer<'a, W> {
         }
     }
 
+    fn write_unit(&mut self) -> Result<()> {
+        let amp = !self.first.swap(false, Ordering::Relaxed);
+        if let Some(ref key) = self.key {
+            write!(
+                self.writer,
+                "{}{}=",
+                amp.then_some("&").unwrap_or_default(),
+                key,
+            )
+            .map_err(Error::from)
+        } else {
+            // For top level unit types
+            write!(self.writer, "{}", amp.then_some("&").unwrap_or_default(),).map_err(Error::from)
+        }
+    }
+
     /// Creates a new `QsSerializer` with a distinct key, but `writer` and
     ///`first` referring to the original data.
     fn new_from_ref<'b: 'a>(other: &'a mut QsSerializer<'b, W>) -> QsSerializer<'a, W> {
@@ -351,15 +358,13 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
     }
 
     fn serialize_unit(mut self) -> Result<Self::Ok> {
-        self.write_value(&[])
+        self.write_unit()
     }
 
-    /// Returns an error.
-    fn serialize_unit_struct(mut self, name: &'static str) -> Result<Self::Ok> {
-        self.write_value(name.as_bytes())
+    fn serialize_unit_struct(mut self, _: &'static str) -> Result<Self::Ok> {
+        self.write_unit()
     }
 
-    /// Returns an error.
     fn serialize_unit_variant(
         mut self,
         _name: &'static str,
@@ -369,7 +374,6 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
         self.write_value(variant.as_bytes())
     }
 
-    /// Returns an error.
     fn serialize_newtype_struct<T: ?Sized + ser::Serialize>(
         self,
         _name: &'static str,
@@ -378,7 +382,6 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
         value.serialize(self)
     }
 
-    /// Returns an error.
     fn serialize_newtype_variant<T: ?Sized + ser::Serialize>(
         mut self,
         _name: &'static str,
@@ -395,11 +398,9 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
     }
 
     fn serialize_some<T: ?Sized + ser::Serialize>(self, value: &T) -> Result<Self::Ok> {
-        // Err(Error::Unsupported)
         value.serialize(self)
     }
 
-    /// Returns an error.
     fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
         Ok(QsSeq(self, 0))
     }
@@ -408,7 +409,6 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
         Ok(QsSeq(self, 0))
     }
 
-    /// Returns an error.
     fn serialize_tuple_struct(
         self,
         _name: &'static str,
@@ -424,7 +424,6 @@ impl<'a, W: Write> ser::Serializer for QsSerializer<'a, W> {
         variant: &'static str,
         _len: usize,
     ) -> Result<Self::SerializeTupleVariant> {
-        // self.write(variant)?;
         self.extend_key(variant);
         Ok(QsSeq(self, 0))
     }
diff --git a/tests/test_deserialize.rs b/tests/test_deserialize.rs
index 92e5f53..b2f9dc8 100644
--- a/tests/test_deserialize.rs
+++ b/tests/test_deserialize.rs
@@ -699,3 +699,26 @@ fn deserialize_map_with_int_keys() {
     serde_qs::from_str::<Mapping>("mapping[1]=2&mapping[1]=4")
         .expect_err("should error with repeated key");
 }
+
+#[test]
+fn deserialize_unit_types() {
+    #[derive(Debug, Deserialize, PartialEq)]
+    struct A;
+    #[derive(Debug, Deserialize, PartialEq)]
+    struct B<'a> {
+        t: (),
+        a: &'a str,
+    }
+
+    let test: () = serde_qs::from_str("").unwrap();
+    assert_eq!(test, ());
+
+    let test: A = serde_qs::from_str("").unwrap();
+    assert_eq!(test, A);
+
+    let test: B = serde_qs::from_str("a=test&t=").unwrap();
+    assert_eq!(test, B { t: (), a: "test" });
+
+    let test: B = serde_qs::from_str("t=&a=test").unwrap();
+    assert_eq!(test, B { t: (), a: "test" });
+}
diff --git a/tests/test_serialize.rs b/tests/test_serialize.rs
index b7c4a21..0b9f0c6 100644
--- a/tests/test_serialize.rs
+++ b/tests/test_serialize.rs
@@ -236,3 +236,41 @@ fn test_serializer() {
 
     assert_eq!(writer, b"a[0]=3&a[1]=2&b=a");
 }
+
+#[test]
+fn test_serializer_unit() {
+    use serde::Serialize;
+    #[derive(Serialize)]
+    struct A;
+    #[derive(Serialize)]
+    struct B {
+        t: (),
+    }
+
+    let mut writer = Vec::new();
+    {
+        let serializer = &mut qs::Serializer::new(&mut writer);
+        let q = ();
+        q.serialize(serializer).unwrap();
+    }
+
+    assert_eq!(writer, b"", "we are testing ()");
+    writer.clear();
+
+    {
+        let serializer = &mut qs::Serializer::new(&mut writer);
+        let q = A;
+        q.serialize(serializer).unwrap();
+    }
+
+    assert_eq!(writer, b"", "we are testing A");
+    writer.clear();
+
+    {
+        let serializer = &mut qs::Serializer::new(&mut writer);
+        let q = B { t: () };
+        q.serialize(serializer).unwrap();
+    }
+
+    assert_eq!(writer, b"t=", "we are testing B{{t: ()}}");
+}
-- 
cgit v1.2.3