diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 142e7eca1a84b..4dfd231a04016 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2195,6 +2195,39 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
     return width, height
 
 
+def _update_schema_with_optional_fields(input_dict: dict) -> dict:
+    """Convert optional fields to required fields allowing 'null' type."""
+
+    def _update_properties(schema: dict) -> None:
+        if schema.get("type") != "object":
+            return
+
+        properties = schema.get("properties", {})
+        required_fields = schema.get("required", [])
+
+        for field, field_schema in properties.items():
+            field_schema.pop("default", None)
+
+            if field_schema.get("type") == "object":
+                _update_properties(field_schema)
+
+            if field not in required_fields:
+                original_type = field_schema.get("type")
+                if isinstance(original_type, str):
+                    field_schema["type"] = [original_type, "null"]
+                elif isinstance(original_type, list) and "null" not in original_type:
+                    field_schema["type"].append("null")
+
+                required_fields.append(field)
+
+        schema["required"] = required_fields
+
+    schema = input_dict.get("json_schema", {}).get("schema", {})
+    _update_properties(schema)
+
+    return input_dict
+
+
 def _convert_to_openai_response_format(
     schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
 ) -> Union[Dict, TypeBaseModel]:
@@ -2225,6 +2258,8 @@ def _convert_to_openai_response_format(
             f"'strict' is only specified in one place."
         )
         raise ValueError(msg)
+    if strict:
+        _update_schema_with_optional_fields(response_format)
     return response_format
 
 
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 2e6cca0cd2d96..4494a1bccf1c0 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -18,7 +18,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from pydantic import BaseModel, Field
-from typing_extensions import TypedDict
+from typing_extensions import Annotated, TypedDict
 
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@@ -822,6 +822,71 @@ def test__convert_to_openai_response_format() -> None:
     with pytest.raises(ValueError):
         _convert_to_openai_response_format(response_format, strict=False)
 
+    # Test handling of optional fields
+    ## TypedDict
+    class Entity(TypedDict):
+        """Extracted entity."""
+
+        animal: Annotated[str, ..., "The animal"]
+        color: Annotated[Optional[str], None, "The color"]
+
+    actual = _convert_to_openai_response_format(Entity, strict=True)
+    expected = {
+        "type": "json_schema",
+        "json_schema": {
+            "name": "Entity",
+            "description": "Extracted entity.",
+            "strict": True,
+            "schema": {
+                "type": "object",
+                "properties": {
+                    "animal": {"description": "The animal", "type": "string"},
+                    "color": {"description": "The color", "type": ["string", "null"]},
+                },
+                "required": ["animal", "color"],
+                "additionalProperties": False,
+            },
+        },
+    }
+    assert expected == actual
+
+    ## JSON Schema
+    class EntityModel(BaseModel):
+        """Extracted entity."""
+
+        animal: str = Field(description="The animal")
+        color: Optional[str] = Field(default=None, description="The color")
+
+    actual = _convert_to_openai_response_format(
+        EntityModel.model_json_schema(), strict=True
+    )
+    expected = {
+        "type": "json_schema",
+        "json_schema": {
+            "name": "EntityModel",
+            "description": "Extracted entity.",
+            "strict": True,
+            "schema": {
+                "properties": {
+                    "animal": {
+                        "description": "The animal",
+                        "title": "Animal",
+                        "type": "string",
+                    },
+                    "color": {
+                        "anyOf": [{"type": "string"}, {"type": "null"}],
+                        "description": "The color",
+                        "title": "Color",
+                    },
+                },
+                "required": ["animal", "color"],
+                "type": "object",
+                "additionalProperties": False,
+            },
+        },
+    }
+    assert expected == actual
+
 
 @pytest.mark.parametrize("method", ["function_calling", "json_schema"])
 @pytest.mark.parametrize("strict", [True, None])
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index f569135004497..26710f03cdc87 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -21,6 +21,7 @@
 from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
+from typing_extensions import Annotated, TypedDict
 
 from langchain_tests.unit_tests.chat_models import (
     ChatModelTests,
@@ -1293,6 +1294,7 @@ def has_tool_calling(self) -> bool:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
 
+        # Pydantic
        class Joke(BaseModel):
             """Joke to tell user."""
 
@@ -1310,6 +1312,22 @@ class Joke(BaseModel):
         joke_result = chat.invoke("Give me a joke about cats, include the punchline.")
         assert isinstance(joke_result, Joke)
 
+        # Schema
+        chat = model.with_structured_output(Joke.model_json_schema())
+        result = chat.invoke("Tell me a joke about cats.")
+        assert isinstance(result, dict)
+
+        # TypedDict
+        class JokeDict(TypedDict):
+            """Joke to tell user."""
+
+            setup: Annotated[str, ..., "question to set up a joke"]
+            punchline: Annotated[Optional[str], None, "answer to resolve the joke"]
+
+        chat = model.with_structured_output(JokeDict)
+        result = chat.invoke("Tell me a joke about cats.")
+        assert isinstance(result, dict)
+
     def test_json_mode(self, model: BaseChatModel) -> None:
         """Test structured output via `JSON mode. `_