Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi authored Jan 12, 2025
2 parents cd47554 + de3b2cd commit 02fa8e5
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ Model customization:
pydantic: datetime, dataclass: str, ...)
--reuse-model Reuse models on the field when a module has the model with the same
content
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12}
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12,3.13}
target python version (default: 3.8)
--treat-dot-as-module
treat dotted module names as modules
Expand Down
34 changes: 23 additions & 11 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from datamodel_code_generator.format import (
DatetimeClassType,
PythonVersion,
black_find_project_root,
is_supported_in_black,
)
from datamodel_code_generator.parser import LiteralType
Expand Down Expand Up @@ -366,6 +365,26 @@ def merge_args(self, args: Namespace) -> None:
setattr(self, field_name, getattr(parsed_args, field_name))


def _get_pyproject_toml_config(source: Path) -> Optional[Dict[str, Any]]:
"""Find and return the [tool.datamodel-codgen] section of the closest
pyproject.toml if it exists.
"""

current_path = source
while current_path != current_path.parent:
if (current_path / 'pyproject.toml').is_file():
pyproject_toml = load_toml(current_path / 'pyproject.toml')
if 'datamodel-codegen' in pyproject_toml.get('tool', {}):
return pyproject_toml['tool']['datamodel-codegen']

if (current_path / '.git').exists():
# Stop early if we see a git repository root.
return None

current_path = current_path.parent
return None


def main(args: Optional[Sequence[str]] = None) -> Exit:
"""Main function."""

Expand All @@ -383,16 +402,9 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
print(version)
exit(0)

root = black_find_project_root((Path().resolve(),))
pyproject_toml_path = root / 'pyproject.toml'
if pyproject_toml_path.is_file():
pyproject_toml: Dict[str, Any] = {
k.replace('-', '_'): v
for k, v in load_toml(pyproject_toml_path)
.get('tool', {})
.get('datamodel-codegen', {})
.items()
}
pyproject_config = _get_pyproject_toml_config(Path().resolve())
if pyproject_config is not None:
pyproject_toml = {k.replace('-', '_'): v for k, v in pyproject_config.items()}
else:
pyproject_toml = {}

Expand Down
18 changes: 15 additions & 3 deletions datamodel_code_generator/model/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,22 @@ def get_member(self, field: DataModelFieldBase) -> Member:

def find_member(self, value: Any) -> Optional[Member]:
repr_value = repr(value)
for field in self.fields: # pragma: no cover
if field.default == repr_value:
# Remove surrounding quotes from the string representation
str_value = str(value).strip('\'"')

for field in self.fields:
# Remove surrounding quotes from field default value
field_default = field.default.strip('\'"')

# Compare values after removing quotes
if field_default == str_value:
return self.get_member(field)
return None # pragma: no cover

# Keep original comparison for backwards compatibility
if field.default == repr_value: # pragma: no cover
return self.get_member(field)

return None

@property
def imports(self) -> Tuple[Import, ...]:
Expand Down
1 change: 1 addition & 0 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_special_path(keyword: str, path: List[str]) -> List[str]:

escape_characters = str.maketrans(
{
'\u0000': r'\x00', # Null byte
'\\': r'\\',
"'": r'\'',
'\b': r'\b',
Expand Down
20 changes: 11 additions & 9 deletions datamodel_code_generator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,22 +362,21 @@ def all_imports(self) -> Iterator[Import]:

@property
def imports(self) -> Iterator[Import]:
# Add base import if exists
if self.import_:
yield self.import_

# Define required imports based on type features and conditions
imports: Tuple[Tuple[bool, Import], ...] = (
(self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
(len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
)
if any(self.literals):
import_literal = (
(
bool(self.literals),
IMPORT_LITERAL
if self.python_version.has_literal_type
else IMPORT_LITERAL_BACKPORT
)
imports = (
*imports,
(any(self.literals), import_literal),
)
else IMPORT_LITERAL_BACKPORT,
),
)

if self.use_generic_container:
if self.use_standard_collections:
Expand All @@ -401,10 +400,13 @@ def imports(self) -> Iterator[Import]:
(self.is_set, IMPORT_SET),
(self.is_dict, IMPORT_DICT),
)

# Yield imports based on conditions
for field, import_ in imports:
if field and import_ != self.import_:
yield import_

# Propagate imports from any dict_key type
if self.dict_key:
yield from self.dict_key.imports

Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ Model customization:
pydantic: datetime, dataclass: str, ...)
--reuse-model Reuse models on the field when a module has the model with the same
content
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12}
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12,3.13}
target python version (default: 3.8)
--treat-dot-as-module
treat dotted module names as modules
Expand Down
69 changes: 69 additions & 0 deletions tests/data/expected/main_kr/pyproject/output.strictstr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# generated by datamodel-codegen:
# filename: api.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import List, Optional

from pydantic import AnyUrl, BaseModel, Field, StrictStr


class Pet(BaseModel):
id: int
name: StrictStr
tag: Optional[StrictStr] = None


class Pets(BaseModel):
__root__: List[Pet]


class User(BaseModel):
id: int
name: StrictStr
tag: Optional[StrictStr] = None


class Users(BaseModel):
__root__: List[User]


class Id(BaseModel):
__root__: StrictStr


class Rules(BaseModel):
__root__: List[StrictStr]


class Error(BaseModel):
code: int
message: StrictStr


class Api(BaseModel):
apiKey: Optional[StrictStr] = Field(
None, description='To be used as a dataset parameter value'
)
apiVersionNumber: Optional[StrictStr] = Field(
None, description='To be used as a version parameter value'
)
apiUrl: Optional[AnyUrl] = Field(
None, description="The URL describing the dataset's fields"
)
apiDocumentationUrl: Optional[AnyUrl] = Field(
None, description='A URL to the API console for each API'
)


class Apis(BaseModel):
__root__: List[Api]


class Event(BaseModel):
name: Optional[StrictStr] = None


class Result(BaseModel):
event: Optional[Event] = None
67 changes: 67 additions & 0 deletions tests/main/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from datamodel_code_generator.format import PythonVersion
from datamodel_code_generator.imports import (
IMPORT_LITERAL,
IMPORT_LITERAL_BACKPORT,
IMPORT_OPTIONAL,
)
from datamodel_code_generator.types import DataType


class TestDataType:
def test_imports_with_literal_one(self):
"""Test imports for a DataType with single literal value"""
data_type = DataType(literals=[''], python_version=PythonVersion.PY_38)

# Convert iterator to list for assertion
imports = list(data_type.imports)
assert IMPORT_LITERAL in imports
assert len(imports) == 1

def test_imports_with_literal_one_and_optional(self):
"""Test imports for an optional DataType with single literal value"""
data_type = DataType(
literals=[''], is_optional=True, python_version=PythonVersion.PY_38
)

imports = list(data_type.imports)
assert IMPORT_LITERAL in imports
assert IMPORT_OPTIONAL in imports
assert len(imports) == 2

def test_imports_with_literal_empty(self):
"""Test imports for a DataType with no literal values"""
data_type = DataType(literals=[], python_version=PythonVersion.PY_38)

imports = list(data_type.imports)
assert len(imports) == 0

def test_imports_with_literal_python37(self):
"""Test imports for a DataType with literal in Python 3.7"""
data_type = DataType(literals=[''], python_version=PythonVersion.PY_37)

imports = list(data_type.imports)
assert IMPORT_LITERAL_BACKPORT in imports
assert len(imports) == 1

def test_imports_with_nested_dict_key(self):
"""Test imports for a DataType with dict_key containing literals"""
dict_key_type = DataType(literals=['key'], python_version=PythonVersion.PY_38)

data_type = DataType(python_version=PythonVersion.PY_38, dict_key=dict_key_type)

imports = list(data_type.imports)
assert IMPORT_LITERAL in imports
assert len(imports) == 1

def test_imports_without_duplicate_literals(self):
"""Test that literal import is not duplicated"""
dict_key_type = DataType(literals=['key1'], python_version=PythonVersion.PY_38)

data_type = DataType(
literals=['key2'],
python_version=PythonVersion.PY_38,
dict_key=dict_key_type,
)

imports = list(data_type.imports)
assert IMPORT_LITERAL in imports
Loading

0 comments on commit 02fa8e5

Please sign in to comment.