-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
102 lines (87 loc) · 3.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import typing
from functools import partial
import anyio
from fastapi.responses import StreamingResponse
from starlette.types import Send, Scope, Receive
class PathMatchingTree:
    """
    PathMatchingTree is a data structure that can be used to match a path with a value.

    It supports exact match, partial (prefix) match, and wildcard match.
    Paths are split on '/', empty segments are ignored, and each remaining
    segment is one level of the tree. A '*' segment in the config matches any
    single segment that has no exact child at that level.

    For example, if the tree is built with the following config:

        {
            "/foo/bar": "value1",
            "/baz/qux": "value2",
            "/foo/*": "value3",
            "/foo/*/bar": "value4"
        }

    Then the following path will match the corresponding value:

        /foo/bar      -> value1
        /baz/qux      -> value2
        /foo/baz      -> value3
        /foo/baz/bar  -> value4
        /foo/baz/bar2 -> value3

    Notes:
        * An exact child is always preferred over '*', and the walk never
          backtracks: once an exact child has been taken, a sibling '*'
          branch is not reconsidered.
        * A configured value of None is indistinguishable from "no value".
    """

    # Children of this node, keyed by path segment ('*' is the wildcard key).
    child: typing.Dict[str, 'PathMatchingTree']
    # Value configured for the path ending at this node; None when unset.
    value: typing.Any = None

    def __init__(self, config):
        """Build a tree from *config*, a mapping of path pattern -> value."""
        self.child = {}
        self._build_tree(config)

    def _build_tree(self, config):
        """Insert every (path pattern, value) pair of *config* into the tree."""
        for pattern, value in config.items():
            self._add(pattern.split('/'), value)

    def _add(self, parts, value):
        """Create the nodes for the segments in *parts* and store *value* on the last one."""
        node = self
        for part in parts:
            if not part:
                # Leading, trailing and doubled slashes produce empty segments.
                continue
            nxt = node.child.get(part)
            if nxt is None:
                nxt = node.child[part] = PathMatchingTree({})
            node = nxt
        node.value = value

    def get_matching(self, path):
        """
        Return the value configured for the longest matching prefix of *path*.

        The tree is walked segment by segment, preferring an exact child and
        falling back to the '*' child. The walk stops at the first segment
        that matches neither. The value of the deepest visited node that
        actually carries a value is returned, so stopping on (or ending at) an
        intermediate node without a value falls back to the nearest configured
        ancestor instead of returning None. Returns None when no visited node
        has a value.
        """
        node = self
        best = self.value
        for part in path.split('/'):
            if not part:
                continue
            nxt = node.child.get(part)
            if nxt is None:
                nxt = node.child.get('*')
            if nxt is None:
                break
            node = nxt
            if node.value is not None:
                best = node.value
        return best
class OverrideStreamResponse(StreamingResponse):
    """
    Override StreamingResponse to support lazy send response status_code and response headers.

    The 'http.response.start' message is not sent up front; it is sent right
    before the first body chunk (or right before the terminating empty body
    when the iterator yields nothing). `status_code` and `raw_headers` are
    therefore read at that moment rather than when the response starts.
    """

    async def stream_response(self, send: Send) -> None:
        """
        Stream `body_iterator` to the client through *send*.

        Sends the response start message lazily before the first chunk, then
        every chunk with `more_body=True`, and finally an empty body with
        `more_body=False` to close the response.
        """
        first_chunk = True
        async for chunk in self.body_iterator:
            if first_chunk:
                await self.send_request_header(send)
                first_chunk = False
            # bytes-like chunks go out untouched; anything else (str) is encoded.
            if not isinstance(chunk, (bytes, memoryview)):
                chunk = chunk.encode(self.charset)
            await send({'type': 'http.response.body', 'body': chunk, 'more_body': True})

        # Empty iterator: the start message has not been sent yet.
        if first_chunk:
            await self.send_request_header(send)
        await send({'type': 'http.response.body', 'body': b'', 'more_body': False})

    async def send_request_header(self, send: Send) -> None:
        """Send the 'http.response.start' message with the current status code and headers."""
        await send(
            {
                'type': 'http.response.start',
                'status': self.status_code,
                'headers': self.raw_headers,
            }
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        ASGI entry point.

        Runs the body streaming and the client-disconnect listener
        concurrently; whichever finishes first cancels the other. The
        background task, if any, runs after the task group has exited.
        """
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
                await func()
                # CancelScope.cancel() is a plain (non-async) method in
                # anyio >= 3; awaiting its result fails on anyio 4.
                task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()