From 8425e7f305b4f61ba247eea6b46db75934e9d7cb Mon Sep 17 00:00:00 2001
From: Sanskar Jethi <29942790+sansyrox@users.noreply.github.com>
Date: Sun, 26 Nov 2023 01:35:37 +0530
Subject: [PATCH] fix: add a multi map data structure to query parameters
(#699)
* fix: add a multi map data structure to query parameters
* Add docstrings
* Update documentation
* fix: improve one public api
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../api_reference/getting_started.mdx | 4 +-
.../api_reference/request_object.mdx | 6 +-
.../example_app/modeling_routes.mdx | 4 +-
integration_tests/base_routes.py | 4 +-
integration_tests/test_get_requests.py | 2 +-
robyn/robyn.pyi | 94 ++++++++++++++-
src/executors/mod.rs | 1 +
src/lib.rs | 2 +
src/types/mod.rs | 1 +
src/types/multimap.rs | 112 ++++++++++++++++++
src/types/request.rs | 25 ++--
unit_tests/test_request_object.py | 2 +-
12 files changed, 234 insertions(+), 23 deletions(-)
create mode 100644 src/types/multimap.rs
diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx
index f2b72c960..59123a4ae 100644
--- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx
+++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx
@@ -230,7 +230,7 @@ Batman was curious about how to access path parameters and query parameters from
```python {{ title: 'untyped' }}
@app.get("/query")
async def query_get(request):
- query_data = request["queries"]
+ query_data = request.query_params.to_dict()
return jsonify(query_data)
```
@@ -239,7 +239,7 @@ Batman was curious about how to access path parameters and query parameters from
@app.get("/query")
async def query_get(request: Request):
- query_data = request["queries"]
+ query_data = request.query_params.to_dict()
return jsonify(query_data)
```
diff --git a/docs_src/src/pages/documentation/api_reference/request_object.mdx b/docs_src/src/pages/documentation/api_reference/request_object.mdx
index 5981c533b..5f824d897 100644
--- a/docs_src/src/pages/documentation/api_reference/request_object.mdx
+++ b/docs_src/src/pages/documentation/api_reference/request_object.mdx
@@ -17,7 +17,7 @@ The request object is created in Rust side but is exposed to Python as a datacla
Attributes:
-queries (dict[str, str]): The query parameters of the request. `e.g. /user?id=123 -> {"id": "123"}`
+query_params (QueryParams): The query parameters of the request. `e.g. /user?id=123 -> {"id": [ "123" ]}`
headers (dict[str, str]): The headers of the request. `e.g. {"Content-Type": "application/json"}`
@@ -54,7 +54,7 @@ identity (Optional[Identity]): The identity of the client
@dataclass
class Request:
"""
- queries: dict[str, str]
+ query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
@@ -69,7 +69,7 @@ identity (Optional[Identity]): The identity of the client
@dataclass
class Request:
"""
- queries: dict[str, str]
+ query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
diff --git a/docs_src/src/pages/documentation/example_app/modeling_routes.mdx b/docs_src/src/pages/documentation/example_app/modeling_routes.mdx
index 3dad8ea89..d76f974b3 100644
--- a/docs_src/src/pages/documentation/example_app/modeling_routes.mdx
+++ b/docs_src/src/pages/documentation/example_app/modeling_routes.mdx
@@ -111,8 +111,8 @@ async def add_crime(request):
@app.get("/crimes")
async def get_crimes(request):
with SessionLocal() as db:
- skip = request.queries.get("skip", 0)
- limit = request.queries.get("limit", 100)
+ skip = request.query_params.get("skip", 0)
+ limit = request.query_params.get("limit", 100)
crimes = crud.get_crimes(db, skip=skip, limit=limit)
return crimes
diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py
index 7ad5b935d..b1de33b0b 100644
--- a/integration_tests/base_routes.py
+++ b/integration_tests/base_routes.py
@@ -448,13 +448,13 @@ async def file_download_async():
@app.get("/sync/queries")
def sync_queries(request: Request):
- query_data = request.queries
+ query_data = request.query_params.to_dict()
return jsonify(query_data)
@app.get("/async/queries")
async def async_query(request: Request):
- query_data = request.queries
+ query_data = request.query_params.to_dict()
return jsonify(query_data)
diff --git a/integration_tests/test_get_requests.py b/integration_tests/test_get_requests.py
index ac51df00c..d3745c04c 100644
--- a/integration_tests/test_get_requests.py
+++ b/integration_tests/test_get_requests.py
@@ -47,7 +47,7 @@ def check_response(r: Response):
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_queries(function_type: str, session):
r = get(f"/{function_type}/queries?hello=robyn")
- assert r.json() == {"hello": "robyn"}
+ assert r.json() == {"hello": ["robyn"]}
r = get(f"/{function_type}/queries")
assert r.json() == {}
diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi
index 6d5d756cc..a99d0c274 100644
--- a/robyn/robyn.pyi
+++ b/robyn/robyn.pyi
@@ -61,13 +61,103 @@ class Url:
class Identity:
claims: dict[str, str]
+@dataclass
+class QueryParams:
+ """
+ The query params object passed to the route handler.
+
+ Attributes:
+ queries (dict[str, list[str]]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
+ """
+
+ queries: dict[str, list[str]]
+
+ def set(self, key: str, value: str) -> None:
+ """
+ Sets the value of the query parameter with the given key.
+ If the key already exists, the value will be appended to the list of values.
+
+ Args:
+ key (str): The key of the query parameter
+ value (str): The value of the query parameter
+ """
+ pass
+
+ def get(self, key: str, default: Optional[str]) -> Optional[str]:
+ """
+ Gets the last value of the query parameter with the given key.
+
+ Args:
+ key (str): The key of the query parameter
+ default (Optional[str]): The default value if the key does not exist
+ """
+ pass
+
+ def empty(self) -> bool:
+ """
+ Returns:
+ True if the query params are empty, False otherwise
+ """
+ pass
+
+ def contains(self, key: str) -> bool:
+ """
+ Returns:
+ True if the query params contain the key, False otherwise
+
+ Args:
+ key (str): The key of the query parameter
+ """
+ pass
+
+ def get_first(self, key: str) -> Optional[str]:
+ """
+ Gets the first value of the query parameter with the given key.
+
+ Args:
+ key (str): The key of the query parameter
+
+ """
+ pass
+
+ def get_all(self, key: str) -> Optional[list[str]]:
+ """
+ Gets all the values of the query parameter with the given key.
+
+ Args:
+ key (str): The key of the query parameter
+ """
+ pass
+
+ def extend(self, other: QueryParams) -> None:
+ """
+ Extends the query params with the other query params.
+
+ Args:
+ other (QueryParams): The other QueryParams object
+ """
+ pass
+
+ def to_dict(self) -> dict[str, list[str]]:
+ """
+ Returns:
+ The query params as a dictionary
+ """
+ pass
+
+ def __contains__(self, key: str) -> bool:
+ pass
+
+ def __repr__(self) -> str:
+ pass
+
@dataclass
class Request:
"""
The request object passed to the route handler.
Attributes:
- queries (dict[str, str]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
+ query_params (QueryParams): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
headers (dict[str, str]): The headers of the request. e.g. {"Content-Type": "application/json"}
params (dict[str, str]): The parameters of the request. e.g. /user/:id -> {"id": "123"}
body (Union[str, bytes]): The body of the request. If the request is a JSON, it will be a dict.
@@ -75,7 +165,7 @@ class Request:
ip_addr (Optional[str]): The IP Address of the client
"""
- queries: dict[str, str]
+ query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
diff --git a/src/executors/mod.rs b/src/executors/mod.rs
index 23a39ad32..ffb37ee36 100644
--- a/src/executors/mod.rs
+++ b/src/executors/mod.rs
@@ -36,6 +36,7 @@ where
// Execute the middleware function
// type T can be either Request (before middleware) or Response (after middleware)
// Return type can either be a Request or a Response, we wrap it inside an enum for easier handling
+#[inline]
pub async fn execute_middleware_function(
input: &T,
function: &FunctionInfo,
diff --git a/src/lib.rs b/src/lib.rs
index 04e91d266..10cc9ffc2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -17,6 +17,7 @@ use pyo3::{exceptions::PyValueError, prelude::*};
use types::{
function_info::{FunctionInfo, MiddlewareType},
identity::Identity,
+ multimap::QueryParams,
request::PyRequest,
response::PyResponse,
HttpMethod, Url,
@@ -62,6 +63,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
+ m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
diff --git a/src/types/mod.rs b/src/types/mod.rs
index ca58e831d..53fc9656b 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -6,6 +6,7 @@ use pyo3::{
pub mod function_info;
pub mod identity;
+pub mod multimap;
pub mod request;
pub mod response;
diff --git a/src/types/multimap.rs b/src/types/multimap.rs
new file mode 100644
index 000000000..b0f4e2a62
--- /dev/null
+++ b/src/types/multimap.rs
@@ -0,0 +1,112 @@
+use log::debug;
+use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyList};
+use std::collections::HashMap;
+
+// Custom Multimap class
+#[pyclass(name = "QueryParams")]
+#[derive(Clone, Debug, Default)]
+pub struct QueryParams {
+ pub queries: HashMap>,
+}
+
+#[pymethods]
+impl QueryParams {
+ #[new]
+ pub fn new() -> Self {
+ QueryParams {
+ queries: HashMap::new(),
+ }
+ }
+
+ pub fn set(&mut self, key: String, value: String) {
+ debug!("Setting key: {} to value: {}", key, value);
+ self.queries.entry(key).or_insert_with(Vec::new).push(value);
+ debug!("Multimap: {:?}", self.queries);
+ }
+
+ pub fn get(&self, key: String, default: Option) -> Option {
+ match self.queries.get(&key) {
+ Some(values) => values.last().cloned(),
+ None => default,
+ }
+ }
+
+ pub fn get_first(&self, key: String) -> Option {
+ match self.queries.get(&key) {
+ Some(values) => values.first().cloned(),
+ None => None,
+ }
+ }
+
+ pub fn empty(&self) -> bool {
+ self.queries.is_empty()
+ }
+
+ pub fn contains(&self, key: String) -> bool {
+ self.queries.contains_key(&key)
+ }
+
+ pub fn get_all(&self, key: String) -> Option> {
+ self.queries.get(&key).cloned()
+ }
+
+ pub fn extend(&mut self, other: &mut Self) {
+ for (key, values) in other.queries.iter_mut() {
+ for value in values.iter_mut() {
+ self.set(key.clone(), value.clone());
+ }
+ }
+ }
+
+ pub fn to_dict(&self, py: Python) -> PyResult> {
+ let dict = PyDict::new(py);
+ for (key, values) in self.queries.iter() {
+ let values = PyList::new(py, values.iter());
+ dict.set_item(key, values)?;
+ }
+ Ok(dict.into())
+ }
+
+ pub fn __contains__(&self, key: String) -> bool {
+ self.queries.contains_key(&key)
+ }
+
+ pub fn __repr__(&self) -> String {
+ format!("{:?}", self.queries)
+ }
+}
+
+impl QueryParams {
+ pub fn from_hashmap(map: HashMap>) -> Self {
+ let mut multimap = QueryParams::new();
+ for (key, values) in map {
+ for value in values {
+ multimap.set(key.clone(), value);
+ }
+ }
+ multimap
+ }
+
+ pub fn from_dict(dict: &PyDict) -> Self {
+ let mut multimap = QueryParams::new();
+ for (key, value) in dict.iter() {
+ let key = key.extract::().unwrap();
+ let value = value.extract::().unwrap();
+ multimap.set(key, value);
+ }
+ multimap
+ }
+
+ pub fn contains_key(&self, key: &str) -> bool {
+ self.queries.contains_key(key)
+ }
+
+ pub fn insert(&mut self, key: String, value: Vec) {
+ self.queries.insert(key, value);
+ }
+
+ pub fn get_mut(&mut self, key: &str) -> Option<&Vec> {
+ self.queries.get(key)
+ }
+}
diff --git a/src/types/request.rs b/src/types/request.rs
index 3d1b151fc..cf1c3e385 100644
--- a/src/types/request.rs
+++ b/src/types/request.rs
@@ -6,11 +6,11 @@ use std::collections::HashMap;
use crate::types::{check_body_type, get_body_from_pyobject, Url};
-use super::identity::Identity;
+use super::{identity::Identity, multimap::QueryParams};
#[derive(Default, Debug, Clone, FromPyObject)]
pub struct Request {
- pub queries: HashMap,
+ pub query_params: QueryParams,
pub headers: HashMap,
pub method: String,
pub path_params: HashMap,
@@ -24,7 +24,7 @@ pub struct Request {
impl ToPyObject for Request {
fn to_object(&self, py: Python) -> PyObject {
- let queries = self.queries.clone().into_py(py).extract(py).unwrap();
+ let query_params = self.query_params.clone();
let headers = self.headers.clone().into_py(py).extract(py).unwrap();
let path_params = self.path_params.clone().into_py(py).extract(py).unwrap();
let body = match String::from_utf8(self.body.clone()) {
@@ -33,7 +33,7 @@ impl ToPyObject for Request {
};
let request = PyRequest {
- queries,
+ query_params,
path_params,
headers,
body,
@@ -52,12 +52,15 @@ impl Request {
body: Bytes,
global_headers: &DashMap,
) -> Self {
- let mut queries = HashMap::new();
+ let mut query_params: QueryParams = QueryParams::new();
if !req.query_string().is_empty() {
let split = req.query_string().split('&');
for s in split {
let path_params = s.split_once('=').unwrap_or((s, ""));
- queries.insert(path_params.0.to_string(), path_params.1.to_string());
+ let key = path_params.0.to_string();
+ let value = path_params.1.to_string();
+
+ query_params.set(key, value);
}
}
let headers = req
@@ -79,7 +82,7 @@ impl Request {
let ip_addr = req.peer_addr().map(|val| val.ip().to_string());
Self {
- queries,
+ query_params,
headers,
method: req.method().as_str().to_owned(),
path_params: HashMap::new(),
@@ -95,7 +98,7 @@ impl Request {
#[derive(Clone)]
pub struct PyRequest {
#[pyo3(get, set)]
- pub queries: Py,
+ pub query_params: QueryParams,
#[pyo3(get, set)]
pub headers: Py,
#[pyo3(get, set)]
@@ -117,7 +120,7 @@ impl PyRequest {
#[new]
#[allow(clippy::too_many_arguments)]
pub fn new(
- queries: Py,
+ query_params: &PyDict,
headers: Py,
path_params: Py,
body: Py,
@@ -126,8 +129,10 @@ impl PyRequest {
identity: Option,
ip_addr: Option,
) -> Self {
+ let query_params = QueryParams::from_dict(query_params);
+
Self {
- queries,
+ query_params,
headers,
path_params,
identity,
diff --git a/unit_tests/test_request_object.py b/unit_tests/test_request_object.py
index a4f731881..fccc7fa8a 100644
--- a/unit_tests/test_request_object.py
+++ b/unit_tests/test_request_object.py
@@ -8,7 +8,7 @@ def test_request_object():
path="/user",
)
request = Request(
- queries={},
+ query_params={},
headers={"Content-Type": "application/json"},
path_params={},
body="",