Skip to content

Commit

Permalink
fix: retries should retry timeouted prompts (langfuse#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxdeichmann authored Aug 11, 2024
1 parent 10e4916 commit 16885c7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
31 changes: 19 additions & 12 deletions langfuse/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import typing
import uuid
import backoff
import httpx
from enum import Enum
import time
Expand Down Expand Up @@ -1092,20 +1093,26 @@ def _fetch_prompt_and_update_cache(
)

self.log.debug(f"Fetching prompt '{cache_key}' from server...")
promptResponse = self.client.prompts.get(
self._url_encode(name),
version=version,
label=label,
request_options={
"max_retries": max_retries,
"timeout": fetch_timeout_seconds,
},
)

if promptResponse.type == "chat":
prompt = ChatPromptClient(promptResponse)
@backoff.on_exception(backoff.constant, Exception, max_tries=max_retries)
def fetch_prompts():
return self.client.prompts.get(
self._url_encode(name),
version=version,
label=label,
request_options={
"timeout_in_seconds": fetch_timeout_seconds,
}
if fetch_timeout_seconds is not None
else None,
)

prompt_response = fetch_prompts()

if prompt_response.type == "chat":
prompt = ChatPromptClient(prompt_response)
else:
prompt = TextPromptClient(promptResponse)
prompt = TextPromptClient(prompt_response)

self.prompt_cache.set(cache_key, prompt, ttl_seconds)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,9 @@ def test_get_langchain_prompt_with_jinja2():
labels=["production"],
)

langfuse_prompt = langfuse.get_prompt("test_jinja2")
langfuse_prompt = langfuse.get_prompt(
"test_jinja2", fetch_timeout_seconds=1, max_retries=3
)

assert (
langfuse_prompt.get_langchain_prompt()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def test_get_fresh_prompt(langfuse):
prompt_name,
version=None,
label=None,
request_options={"max_retries": 2, "timeout": None},
request_options=None,
)

assert result == TextPromptClient(prompt)
Expand Down Expand Up @@ -478,7 +478,7 @@ def test_using_custom_prompt_timeouts(langfuse):
prompt_name,
version=None,
label=None,
request_options={"max_retries": 2, "timeout": 1000},
request_options={"timeout_in_seconds": 1000},
)

assert result == TextPromptClient(prompt)
Expand Down Expand Up @@ -734,7 +734,7 @@ def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse):

mock_server_call.side_effect = Exception("Server error")

result_call_2 = langfuse.get_prompt(prompt_name)
result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1)
assert mock_server_call.call_count == 2
assert result_call_2 == prompt_client

Expand Down

0 comments on commit 16885c7

Please sign in to comment.