From 8425e7f305b4f61ba247eea6b46db75934e9d7cb Mon Sep 17 00:00:00 2001 From: Sanskar Jethi <29942790+sansyrox@users.noreply.github.com> Date: Sun, 26 Nov 2023 01:35:37 +0530 Subject: [PATCH] fix: add a multi map data structure to query parameters (#699) * fix: add a multi map data structure to query parameters * Add docstrings * Update documentation * fix: improve one public api * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../api_reference/getting_started.mdx | 4 +- .../api_reference/request_object.mdx | 6 +- .../example_app/modeling_routes.mdx | 4 +- integration_tests/base_routes.py | 4 +- integration_tests/test_get_requests.py | 2 +- robyn/robyn.pyi | 94 ++++++++++++++- src/executors/mod.rs | 1 + src/lib.rs | 2 + src/types/mod.rs | 1 + src/types/multimap.rs | 112 ++++++++++++++++++ src/types/request.rs | 25 ++-- unit_tests/test_request_object.py | 2 +- 12 files changed, 234 insertions(+), 23 deletions(-) create mode 100644 src/types/multimap.rs diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx index f2b72c960..59123a4ae 100644 --- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx +++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx @@ -230,7 +230,7 @@ Batman was curious about how to access path parameters and query parameters from ```python {{ title: 'untyped' }} @app.get("/query") async def query_get(request): - query_data = request["queries"] + query_data = request.query_params.to_dict() return jsonify(query_data) ``` @@ -239,7 +239,7 @@ Batman was curious about how to access path parameters and query parameters from @app.get("/query") async def query_get(request: Request): - query_data = request["queries"] + query_data = request.query_params.to_dict() return jsonify(query_data) ``` diff --git a/docs_src/src/pages/documentation/api_reference/request_object.mdx b/docs_src/src/pages/documentation/api_reference/request_object.mdx index 5981c533b..5f824d897 100644 --- a/docs_src/src/pages/documentation/api_reference/request_object.mdx +++ b/docs_src/src/pages/documentation/api_reference/request_object.mdx @@ -17,7 +17,7 @@ The request object is created in Rust side but is exposed to Python as a datacla Attributes:
  • -queries (dict[str, str]): The query parameters of the request. `e.g. /user?id=123 -> {"id": "123"}` +query_params (QueryParams): The query parameters of the request. `e.g. /user?id=123 -> {"id": [ "123" ]}`
  • headers (dict[str, str]): The headers of the request. `e.g. {"Content-Type": "application/json"}` @@ -54,7 +54,7 @@ identity (Optional[Identity]): The identity of the client @dataclass class Request: """ - queries: dict[str, str] + query_params: QueryParams headers: dict[str, str] path_params: dict[str, str] body: Union[str, bytes] @@ -69,7 +69,7 @@ identity (Optional[Identity]): The identity of the client @dataclass class Request: """ - queries: dict[str, str] + query_params: QueryParams headers: dict[str, str] path_params: dict[str, str] body: Union[str, bytes] diff --git a/docs_src/src/pages/documentation/example_app/modeling_routes.mdx b/docs_src/src/pages/documentation/example_app/modeling_routes.mdx index 3dad8ea89..d76f974b3 100644 --- a/docs_src/src/pages/documentation/example_app/modeling_routes.mdx +++ b/docs_src/src/pages/documentation/example_app/modeling_routes.mdx @@ -111,8 +111,8 @@ async def add_crime(request): @app.get("/crimes") async def get_crimes(request): with SessionLocal() as db: - skip = request.queries.get("skip", 0) - limit = request.queries.get("limit", 100) + skip = request.query_params.get("skip", 0) + limit = request.query_params.get("limit", 100) crimes = crud.get_crimes(db, skip=skip, limit=limit) return crimes diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 7ad5b935d..b1de33b0b 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -448,13 +448,13 @@ async def file_download_async(): @app.get("/sync/queries") def sync_queries(request: Request): - query_data = request.queries + query_data = request.query_params.to_dict() return jsonify(query_data) @app.get("/async/queries") async def async_query(request: Request): - query_data = request.queries + query_data = request.query_params.to_dict() return jsonify(query_data) diff --git a/integration_tests/test_get_requests.py b/integration_tests/test_get_requests.py index ac51df00c..d3745c04c 100644 --- a/integration_tests/test_get_requests.py +++ b/integration_tests/test_get_requests.py @@ -47,7 +47,7 @@ def check_response(r: Response): @pytest.mark.parametrize("function_type", ["sync", "async"]) def test_queries(function_type: str, session): r = get(f"/{function_type}/queries?hello=robyn") - assert r.json() == {"hello": "robyn"} + assert r.json() == {"hello": ["robyn"]} r = get(f"/{function_type}/queries") assert r.json() == {} diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index 6d5d756cc..a99d0c274 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -61,13 +61,103 @@ class Url: class Identity: claims: dict[str, str] +@dataclass +class QueryParams: + """ + The query params object passed to the route handler. + + Attributes: + queries (dict[str, list[str]]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"} + """ + + queries: dict[str, list[str]] + + def set(self, key: str, value: str) -> None: + """ + Sets the value of the query parameter with the given key. + If the key already exists, the value will be appended to the list of values. + + Args: + key (str): The key of the query parameter + value (str): The value of the query parameter + """ + pass + + def get(self, key: str, default: Optional[str]) -> Optional[str]: + """ + Gets the last value of the query parameter with the given key. + + Args: + key (str): The key of the query parameter + default (Optional[str]): The default value if the key does not exist + """ + pass + + def empty(self) -> bool: + """ + Returns: + True if the query params are empty, False otherwise + """ + pass + + def contains(self, key: str) -> bool: + """ + Returns: + True if the query params contain the key, False otherwise + + Args: + key (str): The key of the query parameter + """ + pass + + def get_first(self, key: str) -> Optional[str]: + """ + Gets the first value of the query parameter with the given key. + + Args: + key (str): The key of the query parameter + + """ + pass + + def get_all(self, key: str) -> Optional[list[str]]: + """ + Gets all the values of the query parameter with the given key. + + Args: + key (str): The key of the query parameter + """ + pass + + def extend(self, other: QueryParams) -> None: + """ + Extends the query params with the other query params. + + Args: + other (QueryParams): The other QueryParams object + """ + pass + + def to_dict(self) -> dict[str, list[str]]: + """ + Returns: + The query params as a dictionary + """ + pass + + def __contains__(self, key: str) -> bool: + pass + + def __repr__(self) -> str: + pass + @dataclass class Request: """ The request object passed to the route handler. Attributes: - queries (dict[str, str]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"} + query_params (QueryParams): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"} headers (dict[str, str]): The headers of the request. e.g. {"Content-Type": "application/json"} params (dict[str, str]): The parameters of the request. e.g. /user/:id -> {"id": "123"} body (Union[str, bytes]): The body of the request. If the request is a JSON, it will be a dict. @@ -75,7 +165,7 @@ class Request: ip_addr (Optional[str]): The IP Address of the client """ - queries: dict[str, str] + query_params: QueryParams headers: dict[str, str] path_params: dict[str, str] body: Union[str, bytes] diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 23a39ad32..ffb37ee36 100644 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -36,6 +36,7 @@ where // Execute the middleware function // type T can be either Request (before middleware) or Response (after middleware) // Return type can either be a Request or a Response, we wrap it inside an enum for easier handling +#[inline] pub async fn execute_middleware_function( input: &T, function: &FunctionInfo, diff --git a/src/lib.rs b/src/lib.rs index 04e91d266..10cc9ffc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use pyo3::{exceptions::PyValueError, prelude::*}; use types::{ function_info::{FunctionInfo, MiddlewareType}, identity::Identity, + multimap::QueryParams, request::PyRequest, response::PyResponse, HttpMethod, Url, @@ -62,6 +63,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/types/mod.rs b/src/types/mod.rs index ca58e831d..53fc9656b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,6 +6,7 @@ use pyo3::{ pub mod function_info; pub mod identity; +pub mod multimap; pub mod request; pub mod response; diff --git a/src/types/multimap.rs b/src/types/multimap.rs new file mode 100644 index 000000000..b0f4e2a62 --- /dev/null +++ b/src/types/multimap.rs @@ -0,0 +1,112 @@ +use log::debug; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use std::collections::HashMap; + +// Custom Multimap class +#[pyclass(name = "QueryParams")] +#[derive(Clone, Debug, Default)] +pub struct QueryParams { + pub queries: HashMap>, +} + +#[pymethods] +impl QueryParams { + #[new] + pub fn new() -> Self { + QueryParams { + queries: HashMap::new(), + } + } + + pub fn set(&mut self, key: String, value: String) { + debug!("Setting key: {} to value: {}", key, value); + self.queries.entry(key).or_insert_with(Vec::new).push(value); + debug!("Multimap: {:?}", self.queries); + } + + pub fn get(&self, key: String, default: Option) -> Option { + match self.queries.get(&key) { + Some(values) => values.last().cloned(), + None => default, + } + } + + pub fn get_first(&self, key: String) -> Option { + match self.queries.get(&key) { + Some(values) => values.first().cloned(), + None => None, + } + } + + pub fn empty(&self) -> bool { + self.queries.is_empty() + } + + pub fn contains(&self, key: String) -> bool { + self.queries.contains_key(&key) + } + + pub fn get_all(&self, key: String) -> Option> { + self.queries.get(&key).cloned() + } + + pub fn extend(&mut self, other: &mut Self) { + for (key, values) in other.queries.iter_mut() { + for value in values.iter_mut() { + self.set(key.clone(), value.clone()); + } + } + } + + pub fn to_dict(&self, py: Python) -> PyResult> { + let dict = PyDict::new(py); + for (key, values) in self.queries.iter() { + let values = PyList::new(py, values.iter()); + dict.set_item(key, values)?; + } + Ok(dict.into()) + } + + pub fn __contains__(&self, key: String) -> bool { + self.queries.contains_key(&key) + } + + pub fn __repr__(&self) -> String { + format!("{:?}", self.queries) + } +} + +impl QueryParams { + pub fn from_hashmap(map: HashMap>) -> Self { + let mut multimap = QueryParams::new(); + for (key, values) in map { + for value in values { + multimap.set(key.clone(), value); + } + } + multimap + } + + pub fn from_dict(dict: &PyDict) -> Self { + let mut multimap = QueryParams::new(); + for (key, value) in dict.iter() { + let key = key.extract::().unwrap(); + let value = value.extract::().unwrap(); + multimap.set(key, value); + } + multimap + } + + pub fn contains_key(&self, key: &str) -> bool { + self.queries.contains_key(key) + } + + pub fn insert(&mut self, key: String, value: Vec) { + self.queries.insert(key, value); + } + + pub fn get_mut(&mut self, key: &str) -> Option<&Vec> { + self.queries.get(key) + } +} diff --git a/src/types/request.rs b/src/types/request.rs index 3d1b151fc..cf1c3e385 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -6,11 +6,11 @@ use std::collections::HashMap; use crate::types::{check_body_type, get_body_from_pyobject, Url}; -use super::identity::Identity; +use super::{identity::Identity, multimap::QueryParams}; #[derive(Default, Debug, Clone, FromPyObject)] pub struct Request { - pub queries: HashMap, + pub query_params: QueryParams, pub headers: HashMap, pub method: String, pub path_params: HashMap, @@ -24,7 +24,7 @@ pub struct Request { impl ToPyObject for Request { fn to_object(&self, py: Python) -> PyObject { - let queries = self.queries.clone().into_py(py).extract(py).unwrap(); + let query_params = self.query_params.clone(); let headers = self.headers.clone().into_py(py).extract(py).unwrap(); let path_params = self.path_params.clone().into_py(py).extract(py).unwrap(); let body = match String::from_utf8(self.body.clone()) { @@ -33,7 +33,7 @@ impl ToPyObject for Request { }; let request = PyRequest { - queries, + query_params, path_params, headers, body, @@ -52,12 +52,15 @@ impl Request { body: Bytes, global_headers: &DashMap, ) -> Self { - let mut queries = HashMap::new(); + let mut query_params: QueryParams = QueryParams::new(); if !req.query_string().is_empty() { let split = req.query_string().split('&'); for s in split { let path_params = s.split_once('=').unwrap_or((s, "")); - queries.insert(path_params.0.to_string(), path_params.1.to_string()); + let key = path_params.0.to_string(); + let value = path_params.1.to_string(); + + query_params.set(key, value); } } let headers = req @@ -79,7 +82,7 @@ impl Request { let ip_addr = req.peer_addr().map(|val| val.ip().to_string()); Self { - queries, + query_params, headers, method: req.method().as_str().to_owned(), path_params: HashMap::new(), @@ -95,7 +98,7 @@ impl Request { #[derive(Clone)] pub struct PyRequest { #[pyo3(get, set)] - pub queries: Py, + pub query_params: QueryParams, #[pyo3(get, set)] pub headers: Py, #[pyo3(get, set)] @@ -117,7 +120,7 @@ impl PyRequest { #[new] #[allow(clippy::too_many_arguments)] pub fn new( - queries: Py, + query_params: &PyDict, headers: Py, path_params: Py, body: Py, @@ -126,8 +129,10 @@ impl PyRequest { identity: Option, ip_addr: Option, ) -> Self { + let query_params = QueryParams::from_dict(query_params); + Self { - queries, + query_params, headers, path_params, identity, diff --git a/unit_tests/test_request_object.py b/unit_tests/test_request_object.py index a4f731881..fccc7fa8a 100644 --- a/unit_tests/test_request_object.py +++ b/unit_tests/test_request_object.py @@ -8,7 +8,7 @@ def test_request_object(): path="/user", ) request = Request( - queries={}, + query_params={}, headers={"Content-Type": "application/json"}, path_params={}, body="",