diff --git a/ols/app/models/models.py b/ols/app/models/models.py index 398e14d7..20c03a05 100644 --- a/ols/app/models/models.py +++ b/ols/app/models/models.py @@ -2,7 +2,7 @@ import json from collections import OrderedDict -from typing import Optional, Self +from typing import Any, Dict, Optional, Self, Union from langchain.llms.base import LLM from pydantic import BaseModel, field_validator, model_validator @@ -713,7 +713,7 @@ def cache_entries_to_history( class MessageEncoder(json.JSONEncoder): """Convert Message objects to serializable dictionaries.""" - def default(self, o): + def default(self, o: Any) -> Union[dict, Any]: """Convert a Message object into a serializable dictionary. This method is called when an object cannot be serialized by default @@ -756,11 +756,11 @@ class MessageDecoder(json.JSONDecoder): HumanMessage(content="Hello", ...) """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): """Initialize the MessageDecoder with custom object hook.""" super().__init__(object_hook=self._decode_message, *args, **kwargs) - def _decode_message(self, dct): + def _decode_message(self, dct: Dict[str, Any]) -> Union[HumanMessage, AIMessage, Dict[str, Any]]: """Decode JSON dictionary into Message objects if applicable. Args: