Skip to content

Commit

Permalink
fix: fix clientlets
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 15, 2024
1 parent 361f08c commit fbada7f
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 90 deletions.
1 change: 1 addition & 0 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(

async def close(self):
    """Release the resources potentially held by the Client.

    :return: Return whatever a close method may return
    """
    # delegate cleanup to the instrumentation teardown and surface its result
    teardown_result = self.teardown_instrumentation()
    return teardown_result

Expand Down
12 changes: 8 additions & 4 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ async def start(self):
"""
with ImportExtensions(required=True):
import aiohttp

self.session = aiohttp.ClientSession(
**self._session_kwargs, trace_configs=self._trace_config
)
Expand All @@ -154,6 +153,7 @@ class HTTPClientlet(AioHttpClientlet):
async def send_message(self, url, request: 'Request'):
"""Sends a POST request to the server
:param url: the URL where to send the message
:param request: request as dict
:return: send post message
"""
Expand All @@ -170,14 +170,15 @@ async def send_message(self, url, request: 'Request'):
from docarray.base_doc.io.json import orjson_dumps

request_kwargs['data'] = JinaJsonPayload(value=req_dict)

async with self.session.post(**request_kwargs) as response:
try:
r_str = await response.json()
except aiohttp.ContentTypeError:
r_str = await response.text()
r_status = response.status
handle_response_status(response.status, r_str, url)
return r_status, r_str
handle_response_status(r_status, r_str, url)
return r_status, r_str
except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')
await retry.wait_or_raise_err(
Expand All @@ -196,6 +197,7 @@ async def send_message(self, url, request: 'Request'):
async def send_streaming_message(self, url, doc: 'Document', on: str):
"""Sends a GET SSE request to the server
:param url: the URL where to send the message
:param doc: Request Document
:param on: Request endpoint
:yields: responses
Expand All @@ -218,6 +220,7 @@ async def send_streaming_message(self, url, doc: 'Document', on: str):

async def send_dry_run(self, url, **kwargs):
"""Query the dry_run endpoint from Gateway
:param url: the URL where to send the message
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
Expand Down Expand Up @@ -264,8 +267,9 @@ async def __anext__(self):
class WebsocketClientlet(AioHttpClientlet):
"""Websocket Client to be used with the streamer"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, url, *args, **kwargs) -> None:
    """Create a websocket clientlet bound to a target URL.

    :param url: the URL the websocket connection will be established against
    :param args: positional arguments forwarded to the base clientlet
    :param kwargs: keyword arguments forwarded to the base clientlet
    """
    super().__init__(*args, **kwargs)
    self.url = url
    # connection state: populated once the websocket is actually opened
    self.websocket = None
    # iterator over incoming responses; set up when streaming starts
    self.response_iter = None

Expand Down
154 changes: 85 additions & 69 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ class HTTPBaseClient(BaseClient):
def __init__(self, *args, **kwargs):
    """Initialize the HTTP base client.

    :param args: positional arguments forwarded to the base client
    :param kwargs: keyword arguments forwarded to the base client
    """
    import asyncio  # local import keeps this fix self-contained

    super().__init__(*args, **kwargs)
    self._endpoints = []
    self.reuse_session = False
    # Guards the double-checked lazy creation of the shared ``self.iolet``
    # session (``async with self._lock: if self.iolet is None: ...``).
    # The original assigned ``AsyncExitStack()`` here; that is an async
    # context manager but provides NO mutual exclusion, so concurrent
    # callers could each create their own iolet. A real asyncio.Lock is
    # required — the HTTPClient subclass already overrides with one,
    # confirming the intent.
    self._lock = asyncio.Lock()
    # lazily-created reusable clientlet session; None until first use
    self.iolet = None

async def close(self):
    """Closes the potential resources of the Client.

    :return: Return whatever a close method may return
    """
    # ``super().close()`` is a coroutine function (the base Client defines
    # ``async def close``), so it must be awaited; without the await the
    # base teardown never runs and the caller would receive an un-awaited
    # coroutine object instead of its result.
    ret = await super().close()
    if self.iolet is not None:
        # exit the reusable session's async context with a "no exception"
        # triple, mirroring a normal ``async with`` exit
        await self.iolet.__aexit__(None, None, None)
    return ret

async def _get_endpoints_from_openapi(self, **kwargs):
def extract_paths_by_method(spec):
Expand Down Expand Up @@ -76,25 +82,26 @@ async def _is_flow_ready(self, **kwargs) -> bool:
proto = 'https' if self.args.tls else 'http'
url = f'{proto}://{self.args.host}:{self.args.port}/dry_run'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)

if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
if not self.reuse_session:
iolet = await stack.enter_async_context(
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
await self.iolet.__aenter__()
iolet = self.iolet

response = await iolet.send_dry_run(**kwargs)
response = await iolet.send_dry_run(url=url, **kwargs)
r_status = response.status

r_str = await response.json()
Expand All @@ -112,20 +119,20 @@ async def _is_flow_ready(self, **kwargs) -> bool:
return False

async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
):
"""
:param inputs: the callable
Expand Down Expand Up @@ -168,30 +175,27 @@ async def _get_results(
else:
url = f'{proto}://{self.args.host}:{self.args.port}/post'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
max_attempts=max_attempts,
initial_backoff=initial_backoff,
max_backoff=max_backoff,
backoff_multiplier=backoff_multiplier,
timeout=timeout,
**kwargs,
)
if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
if not self.reuse_session:
iolet = await stack.enter_async_context(
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
self.iolet = await self.iolet.__aenter__()
iolet = self.iolet

def _request_handler(
request: 'Request', **kwargs
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For HTTP Client, for each request in the iterator, we `send_message` using
Expand All @@ -215,7 +219,7 @@ def _result_handler(result):
**streamer_args,
)
async for response in streamer.stream(
request_iterator=request_iterator, results_in_order=results_in_order
request_iterator=request_iterator, results_in_order=results_in_order
):
r_status, r_str = response
handle_response_status(r_status, r_str, url)
Expand Down Expand Up @@ -256,13 +260,13 @@ def _result_handler(result):
yield resp

async def _get_streaming_results(
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
):
proto = 'https' if self.args.tls else 'http'
endpoint = on.strip('/')
Expand All @@ -272,15 +276,27 @@ async def _get_streaming_results(
url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}'
else:
url = f'{proto}://{self.args.host}:{self.args.port}/default'

iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)

async with iolet:
async with AsyncExitStack() as stack:
if not self.reuse_session:
iolet = await stack.enter_async_context(
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
await self.iolet.__aenter__()
iolet = self.iolet
async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on):
if not docarray_v2:
yield Document.from_dict(json.loads(doc))
Expand Down
6 changes: 6 additions & 0 deletions jina/clients/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PostMixin,
ProfileMixin,
)
import asyncio


class HTTPClient(
Expand Down Expand Up @@ -80,3 +81,8 @@ async def async_inputs():
print(resp)
"""

def __init__(self, *args, **kwargs):
    """Initialize the HTTP client.

    :param args: positional arguments forwarded to the base/mixin initializers
    :param kwargs: keyword arguments forwarded to the base/mixin initializers
    """
    super().__init__(*args, **kwargs)
    # guards lazy creation of the shared session when reuse is enabled
    # NOTE(review): the Lock is created in a synchronous __init__, outside a
    # running event loop — on Python < 3.10 asyncio.Lock binds to the loop
    # at creation time; confirm the supported runtime versions.
    self._lock = asyncio.Lock()
    # honor the user-requested session-reuse flag from the parsed args
    self.reuse_session = self.args.reuse_session
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@ def ping(self, **kwargs):
@pytest.mark.parametrize('prefetch', [1, 10])
@pytest.mark.parametrize('concurrent', [15])
@pytest.mark.parametrize('use_stream', [False, True])
@pytest.mark.parametrize('reuse_session', [True, False])
def test_concurrent_clients(
concurrent, protocol, shards, polling, prefetch, reraise, use_stream
concurrent, protocol, shards, polling, prefetch, reraise, use_stream, reuse_session
):

if not use_stream and protocol != 'grpc':
return

if reuse_session and protocol != 'http':
return

def pong(peer_hash, queue, resp: Response):
for d in resp.docs:
queue.put((peer_hash, d.text))

def peer_client(port, protocol, peer_hash, queue):
c = Client(protocol=protocol, port=port)
c = Client(protocol=protocol, port=port, reuse_session=reuse_session)
for _ in range(NUM_REQUESTS):
c.post(
'/ping',
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/docarray_v2/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
)
@pytest.mark.parametrize('return_type', ['batch', 'singleton'])
@pytest.mark.parametrize('include_gateway', [True, False])
def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway):
@pytest.mark.parametrize('reuse_session', [True, False])
def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway, reuse_session):
if reuse_session and 'http' not in protocols:
return
if 'websocket' in protocols and ctxt_manager != 'flow':
return
if not include_gateway and ctxt_manager == 'flow':
Expand Down Expand Up @@ -63,7 +66,7 @@ def foo_single(

with ctxt:
for port, protocol in zip(ports, protocols):
c = Client(port=port, protocol=protocol)
c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
docs = c.post(
on='/foo',
inputs=MySingletonReturnInputDoc(text='hello', price=2),
Expand Down Expand Up @@ -102,7 +105,10 @@ def foo_single(
'protocols', [['grpc'], ['http'], ['websocket'], ['grpc', 'http']]
)
@pytest.mark.parametrize('return_type', ['batch', 'singleton'])
def test_singleton_return_async(ctxt_manager, protocols, return_type):
@pytest.mark.parametrize('reuse_session', [True, False])
def test_singleton_return_async(ctxt_manager, protocols, return_type, reuse_session):
if reuse_session and 'http' not in protocols:
return
if 'websocket' in protocols and ctxt_manager != 'flow':
return

Expand Down Expand Up @@ -149,7 +155,7 @@ async def foo_single(

with ctxt:
for port, protocol in zip(ports, protocols):
c = Client(port=port, protocol=protocol)
c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
docs = c.post(
on='/foo',
inputs=MySingletonReturnInputDoc(text='hello', price=2),
Expand Down
Loading

0 comments on commit fbada7f

Please sign in to comment.