Skip to content

Commit

Permalink
fix: add a multi map data structure to query parameters (#699)
Browse files Browse the repository at this point in the history
* fix: add a multi map data structure to query parameters

* Add docstrings

* Update documentation

* fix: improve one public api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sansyrox and pre-commit-ci[bot] authored Nov 25, 2023
1 parent ed4f79e commit 8425e7f
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Batman was curious about how to access path parameters and query parameters from
```python {{ title: 'untyped' }}
@app.get("/query")
async def query_get(request):
query_data = request["queries"]
query_data = request.query_params.to_dict()
return jsonify(query_data)
```

Expand All @@ -239,7 +239,7 @@ Batman was curious about how to access path parameters and query parameters from

@app.get("/query")
async def query_get(request: Request):
query_data = request["queries"]
query_data = request.query_params.to_dict()
return jsonify(query_data)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The request object is created in Rust side but is exposed to Python as a datacla
Attributes:
</li>
<li>
queries (dict[str, str]): The query parameters of the request. `e.g. /user?id=123 -> {"id": "123"}`
query_params (QueryParams): The query parameters of the request. `e.g. /user?id=123 -> {"id": [ "123" ]}`
</li>
<li>
headers (dict[str, str]): The headers of the request. `e.g. {"Content-Type": "application/json"}`
Expand Down Expand Up @@ -54,7 +54,7 @@ identity (Optional[Identity]): The identity of the client
@dataclass
class Request:
"""
queries: dict[str, str]
query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
Expand All @@ -69,7 +69,7 @@ identity (Optional[Identity]): The identity of the client
@dataclass
class Request:
"""
queries: dict[str, str]
query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ async def add_crime(request):
@app.get("/crimes")
async def get_crimes(request):
with SessionLocal() as db:
skip = request.queries.get("skip", 0)
limit = request.queries.get("limit", 100)
skip = request.query_params.get("skip", 0)
limit = request.query_params.get("limit", 100)
crimes = crud.get_crimes(db, skip=skip, limit=limit)

return crimes
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,13 @@ async def file_download_async():

@app.get("/sync/queries")
def sync_queries(request: Request):
query_data = request.queries
query_data = request.query_params.to_dict()
return jsonify(query_data)


@app.get("/async/queries")
async def async_query(request: Request):
query_data = request.queries
query_data = request.query_params.to_dict()
return jsonify(query_data)


Expand Down
2 changes: 1 addition & 1 deletion integration_tests/test_get_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check_response(r: Response):
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_queries(function_type: str, session):
r = get(f"/{function_type}/queries?hello=robyn")
assert r.json() == {"hello": "robyn"}
assert r.json() == {"hello": ["robyn"]}

r = get(f"/{function_type}/queries")
assert r.json() == {}
94 changes: 92 additions & 2 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,111 @@ class Url:
class Identity:
claims: dict[str, str]

@dataclass
class QueryParams:
"""
The query params object passed to the route handler.
Attributes:
queries (dict[str, list[str]]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
"""

queries: dict[str, list[str]]

def set(self, key: str, value: str) -> None:
"""
Sets the value of the query parameter with the given key.
If the key already exists, the value will be appended to the list of values.
Args:
key (str): The key of the query parameter
value (str): The value of the query parameter
"""
pass

def get(self, key: str, default: Optional[str]) -> Optional[str]:
"""
Gets the last value of the query parameter with the given key.
Args:
key (str): The key of the query parameter
default (Optional[str]): The default value if the key does not exist
"""
pass

def empty(self) -> bool:
"""
Returns:
True if the query params are empty, False otherwise
"""
pass

def contains(self, key: str) -> bool:
"""
Returns:
True if the query params contain the key, False otherwise
Args:
key (str): The key of the query parameter
"""
pass

def get_first(self, key: str) -> Optional[str]:
"""
Gets the first value of the query parameter with the given key.
Args:
key (str): The key of the query parameter
"""
pass

def get_all(self, key: str) -> Optional[list[str]]:
"""
Gets all the values of the query parameter with the given key.
Args:
key (str): The key of the query parameter
"""
pass

def extend(self, other: QueryParams) -> None:
"""
Extends the query params with the other query params.
Args:
other (QueryParams): The other QueryParams object
"""
pass

def to_dict(self) -> dict[str, list[str]]:
"""
Returns:
The query params as a dictionary
"""
pass

def __contains__(self, key: str) -> bool:
pass

def __repr__(self) -> str:
pass

@dataclass
class Request:
"""
The request object passed to the route handler.
Attributes:
queries (dict[str, str]): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
query_params (QueryParams): The query parameters of the request. e.g. /user?id=123 -> {"id": "123"}
headers (dict[str, str]): The headers of the request. e.g. {"Content-Type": "application/json"}
params (dict[str, str]): The parameters of the request. e.g. /user/:id -> {"id": "123"}
body (Union[str, bytes]): The body of the request. If the request is a JSON, it will be a dict.
method (str): The method of the request. e.g. GET, POST, PUT, DELETE
ip_addr (Optional[str]): The IP Address of the client
"""

queries: dict[str, str]
query_params: QueryParams
headers: dict[str, str]
path_params: dict[str, str]
body: Union[str, bytes]
Expand Down
1 change: 1 addition & 0 deletions src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ where
// Execute the middleware function
// type T can be either Request (before middleware) or Response (after middleware)
// Return type can either be a Request or a Response, we wrap it inside an enum for easier handling
#[inline]
pub async fn execute_middleware_function<T>(
input: &T,
function: &FunctionInfo,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pyo3::{exceptions::PyValueError, prelude::*};
use types::{
function_info::{FunctionInfo, MiddlewareType},
identity::Identity,
multimap::QueryParams,
request::PyRequest,
response::PyResponse,
HttpMethod, Url,
Expand Down Expand Up @@ -62,6 +63,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyRequest>()?;
m.add_class::<PyResponse>()?;
m.add_class::<Url>()?;
m.add_class::<QueryParams>()?;
m.add_class::<MiddlewareType>()?;
m.add_class::<HttpMethod>()?;

Expand Down
1 change: 1 addition & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::{

pub mod function_info;
pub mod identity;
pub mod multimap;
pub mod request;
pub mod response;

Expand Down
112 changes: 112 additions & 0 deletions src/types/multimap.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use log::debug;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::collections::HashMap;

// Custom Multimap class
#[pyclass(name = "QueryParams")]
#[derive(Clone, Debug, Default)]
pub struct QueryParams {
pub queries: HashMap<String, Vec<String>>,
}

#[pymethods]
impl QueryParams {
#[new]
pub fn new() -> Self {
QueryParams {
queries: HashMap::new(),
}
}

pub fn set(&mut self, key: String, value: String) {
debug!("Setting key: {} to value: {}", key, value);
self.queries.entry(key).or_insert_with(Vec::new).push(value);
debug!("Multimap: {:?}", self.queries);
}

pub fn get(&self, key: String, default: Option<String>) -> Option<String> {
match self.queries.get(&key) {
Some(values) => values.last().cloned(),
None => default,
}
}

pub fn get_first(&self, key: String) -> Option<String> {
match self.queries.get(&key) {
Some(values) => values.first().cloned(),
None => None,
}
}

pub fn empty(&self) -> bool {
self.queries.is_empty()
}

pub fn contains(&self, key: String) -> bool {
self.queries.contains_key(&key)
}

pub fn get_all(&self, key: String) -> Option<Vec<String>> {
self.queries.get(&key).cloned()
}

pub fn extend(&mut self, other: &mut Self) {
for (key, values) in other.queries.iter_mut() {
for value in values.iter_mut() {
self.set(key.clone(), value.clone());
}
}
}

pub fn to_dict(&self, py: Python) -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
for (key, values) in self.queries.iter() {
let values = PyList::new(py, values.iter());
dict.set_item(key, values)?;
}
Ok(dict.into())
}

pub fn __contains__(&self, key: String) -> bool {
self.queries.contains_key(&key)
}

pub fn __repr__(&self) -> String {
format!("{:?}", self.queries)
}
}

impl QueryParams {
pub fn from_hashmap(map: HashMap<String, Vec<String>>) -> Self {
let mut multimap = QueryParams::new();
for (key, values) in map {
for value in values {
multimap.set(key.clone(), value);
}
}
multimap
}

pub fn from_dict(dict: &PyDict) -> Self {
let mut multimap = QueryParams::new();
for (key, value) in dict.iter() {
let key = key.extract::<String>().unwrap();
let value = value.extract::<String>().unwrap();
multimap.set(key, value);
}
multimap
}

pub fn contains_key(&self, key: &str) -> bool {
self.queries.contains_key(key)
}

pub fn insert(&mut self, key: String, value: Vec<String>) {
self.queries.insert(key, value);
}

pub fn get_mut(&mut self, key: &str) -> Option<&Vec<String>> {
self.queries.get(key)
}
}
Loading

0 comments on commit 8425e7f

Please sign in to comment.