Skip to content

Commit

Permalink
Merge pull request #296 from tisnik/constant-query-endpoints
Browse files Browse the repository at this point in the history
Constant query endpoints
  • Loading branch information
tisnik authored Jan 20, 2025
2 parents 22764ee + b31cf21 commit 67059b4
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 110 deletions.
113 changes: 43 additions & 70 deletions tests/e2e/test_query_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

from . import test_api

QUERY_ENDPOINT = "/v1/query"


def test_invalid_question():
"""Check the REST API /v1/query with POST HTTP method for invalid question."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
cid = suid.get_suid()
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": cid, "query": "how to make burger?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -47,10 +48,9 @@ def test_invalid_question():

def test_invalid_question_without_conversation_id():
"""Check the REST API /v1/query with invalid question and without conversation ID."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"query": "how to make burger?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand Down Expand Up @@ -82,14 +82,13 @@ def test_invalid_question_without_conversation_id():

def test_query_call_without_payload():
"""Check the REST API /v1/query with POST HTTP method when no payload is provided."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
timeout=test_api.LLM_REST_API_TIMEOUT,
)
assert response.status_code == requests.codes.unprocessable_entity
Expand All @@ -103,14 +102,13 @@ def test_query_call_without_payload():

def test_query_call_with_improper_payload():
"""Check the REST API /v1/query with POST HTTP method when improper payload is provided."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"parameter": "this-is-not-proper-question-my-friend"},
timeout=test_api.NON_LLM_REST_API_TIMEOUT,
)
Expand All @@ -125,15 +123,13 @@ def test_query_call_with_improper_payload():

def test_valid_question_improper_conversation_id() -> None:
"""Check the REST API /v1/query with POST HTTP method for improper conversation ID."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.internal_server_error,
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": "not-uuid", "query": "what is kubernetes?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -153,13 +149,11 @@ def test_valid_question_improper_conversation_id() -> None:
@retry(max_attempts=3, wait_between_runs=10)
def test_valid_question_missing_conversation_id() -> None:
"""Check the REST API /v1/query with POST HTTP method for missing conversation ID."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client, endpoint, status_code=requests.codes.ok
pytest.metrics_client, QUERY_ENDPOINT, status_code=requests.codes.ok
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": "", "query": "what is kubernetes?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -179,18 +173,17 @@ def test_valid_question_missing_conversation_id() -> None:

def test_too_long_question() -> None:
"""Check the REST API /v1/query with too long question."""
endpoint = "/v1/query"
# let's make the query really large, larger than the context window size
query = "what is kubernetes?" * 10000

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.request_entity_too_large,
):
cid = suid.get_suid()
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": cid, "query": query},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -207,12 +200,10 @@ def test_too_long_question() -> None:
@pytest.mark.rag
def test_valid_question() -> None:
"""Check the REST API /v1/query with POST HTTP method for valid question and no yaml."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
cid = suid.get_suid()
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": cid, "query": "what is kubernetes?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -237,12 +228,10 @@ def test_valid_question() -> None:
@pytest.mark.rag
def test_ocp_docs_version_same_as_cluster_version() -> None:
"""Check that the version of OCP docs matches the cluster we're on."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
cid = suid.get_suid()
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"conversation_id": cid,
"query": "welcome openshift container platform documentation",
Expand All @@ -267,13 +256,12 @@ def test_valid_question_tokens_counter() -> None:
pytest.metrics_client
)

endpoint = "/v1/query"
with (
metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint),
metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT),
metrics_utils.TokenCounterChecker(pytest.metrics_client, model, provider),
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"query": "what is kubernetes?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -287,13 +275,12 @@ def test_invalid_question_tokens_counter() -> None:
pytest.metrics_client
)

endpoint = "/v1/query"
with (
metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint),
metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT),
metrics_utils.TokenCounterChecker(pytest.metrics_client, model, provider),
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"query": "how to make burger?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -307,11 +294,10 @@ def test_token_counters_for_query_call_without_payload() -> None:
pytest.metrics_client
)

endpoint = "/v1/query"
with (
metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
),
metrics_utils.TokenCounterChecker(
Expand All @@ -323,7 +309,7 @@ def test_token_counters_for_query_call_without_payload() -> None:
),
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
timeout=test_api.LLM_REST_API_TIMEOUT,
)
assert response.status_code == requests.codes.unprocessable_entity
Expand All @@ -336,11 +322,10 @@ def test_token_counters_for_query_call_with_improper_payload() -> None:
pytest.metrics_client
)

endpoint = "/v1/query"
with (
metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
),
metrics_utils.TokenCounterChecker(
Expand All @@ -352,7 +337,7 @@ def test_token_counters_for_query_call_with_improper_payload() -> None:
),
):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"parameter": "this-is-not-proper-question-my-friend"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -364,11 +349,9 @@ def test_token_counters_for_query_call_with_improper_payload() -> None:
@retry(max_attempts=3, wait_between_runs=10)
def test_rag_question() -> None:
"""Ensure responses include rag references."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"query": "what is openshift virtualization?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -391,11 +374,10 @@ def test_rag_question() -> None:
@pytest.mark.cluster
def test_query_filter() -> None:
"""Ensure responses do not include filtered words and that redacted words are not logged."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
query = "what is foo in bar?"
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"query": query},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand Down Expand Up @@ -437,10 +419,9 @@ def test_query_filter() -> None:
@retry(max_attempts=3, wait_between_runs=10)
def test_conversation_history() -> None:
"""Ensure conversations include previous query history."""
endpoint = "/v1/query"
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint):
with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, QUERY_ENDPOINT):
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"query": "what is ingress in kubernetes?",
},
Expand All @@ -458,7 +439,7 @@ def test_conversation_history() -> None:
# get the conversation id so we can reuse it for the follow up question
cid = json_response["conversation_id"]
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={"conversation_id": cid, "query": "what?"},
timeout=test_api.LLM_REST_API_TIMEOUT,
)
Expand All @@ -475,16 +456,14 @@ def test_conversation_history() -> None:

def test_query_with_provider_but_not_model() -> None:
"""Check the REST API /v1/query with POST HTTP method for provider specified, but no model."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
# just the provider is explicitly specified, but model selection is missing
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"conversation_id": "",
"query": "what is kubernetes?",
Expand All @@ -506,16 +485,14 @@ def test_query_with_provider_but_not_model() -> None:

def test_query_with_model_but_not_provider() -> None:
"""Check the REST API /v1/query with POST HTTP method for model specified, but no provider."""
endpoint = "/v1/query"

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
# just model is explicitly specified, but provider selection is missing
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"conversation_id": "",
"query": "what is kubernetes?",
Expand All @@ -536,19 +513,17 @@ def test_query_with_model_but_not_provider() -> None:

def test_query_with_unknown_provider() -> None:
"""Check the REST API /v1/query with POST HTTP method for unknown provider specified."""
endpoint = "/v1/query"

# retrieve currently selected model
model, _ = metrics_utils.get_enabled_model_and_provider(pytest.metrics_client)

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
# provider is unknown
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"conversation_id": "",
"query": "what is kubernetes?",
Expand All @@ -575,19 +550,17 @@ def test_query_with_unknown_provider() -> None:

def test_query_with_unknown_model() -> None:
"""Check the REST API /v1/query with POST HTTP method for unknown model specified."""
endpoint = "/v1/query"

# retrieve currently selected provider
_, provider = metrics_utils.get_enabled_model_and_provider(pytest.metrics_client)

with metrics_utils.RestAPICallCounterChecker(
pytest.metrics_client,
endpoint,
QUERY_ENDPOINT,
status_code=requests.codes.unprocessable_entity,
):
# model is unknown
response = pytest.client.post(
endpoint,
QUERY_ENDPOINT,
json={
"conversation_id": "",
"query": "what is kubernetes?",
Expand Down
Loading

0 comments on commit 67059b4

Please sign in to comment.