updated modelgen tool
This commit is contained in:
@@ -7,6 +7,7 @@ Supported generators:
|
||||
- TypeScriptGenerator: TypeScript interfaces
|
||||
- ProtobufGenerator: Protocol Buffer definitions
|
||||
- PrismaGenerator: Prisma schema
|
||||
- GrapheneGenerator: Graphene ObjectType/InputObjectType classes
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
@@ -6,12 +6,19 @@ Abstract base class for all code generators.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class BaseGenerator(ABC):
|
||||
"""Abstract base for code generators."""
|
||||
|
||||
def __init__(self, name_map: Dict[str, str] = None):
|
||||
self.name_map = name_map or {}
|
||||
|
||||
def map_name(self, name: str) -> str:
|
||||
"""Apply name_map to a model name."""
|
||||
return self.name_map.get(name, name)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, models: Any, output_path: Path) -> None:
|
||||
"""Generate code for the given models to the specified path."""
|
||||
|
||||
@@ -224,7 +224,8 @@ class DjangoGenerator(BaseGenerator):
|
||||
if default is not dc.MISSING and isinstance(default, Enum):
|
||||
extra.append(f"default={enum_name}.{default.name}")
|
||||
return DJANGO_TYPES["enum"].format(
|
||||
enum_name=enum_name, opts=", " + ", ".join(extra) if extra else ""
|
||||
enum_name=enum_name,
|
||||
opts=", " + ", ".join(extra) if extra else ""
|
||||
)
|
||||
|
||||
# Text fields (based on name heuristics)
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
Pydantic Generator
|
||||
|
||||
Generates Pydantic BaseModel classes from model definitions.
|
||||
Supports two output modes:
|
||||
- File output: flat models (backwards compatible)
|
||||
- Directory output: CRUD variants (Create/Update/Response) per model
|
||||
"""
|
||||
|
||||
import dataclasses as dc
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, List, get_type_hints
|
||||
@@ -13,6 +17,13 @@ from ..loader.schema import EnumDefinition, FieldDefinition, ModelDefinition
|
||||
from ..types import PYDANTIC_RESOLVERS
|
||||
from .base import BaseGenerator
|
||||
|
||||
# Fields to skip per CRUD variant
|
||||
SKIP_FIELDS = {
|
||||
"Create": {"id", "created_at", "updated_at", "status", "error_message"},
|
||||
"Update": {"id", "created_at", "updated_at"},
|
||||
"Response": set(),
|
||||
}
|
||||
|
||||
|
||||
class PydanticGenerator(BaseGenerator):
|
||||
"""Generates Pydantic model files."""
|
||||
@@ -21,52 +32,187 @@ class PydanticGenerator(BaseGenerator):
|
||||
return ".py"
|
||||
|
||||
def generate(self, models, output_path: Path) -> None:
|
||||
"""Generate Pydantic models to output_path."""
|
||||
"""Generate Pydantic models to output_path.
|
||||
|
||||
If output_path is a directory (or doesn't end in .py), generate
|
||||
multi-file CRUD variants. Otherwise, generate flat models to a
|
||||
single file.
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
|
||||
if output_path.suffix != ".py":
|
||||
# Directory mode: CRUD variants
|
||||
self._generate_crud_directory(models, output_path)
|
||||
else:
|
||||
# File mode: flat models (backwards compatible)
|
||||
self._generate_flat_file(models, output_path)
|
||||
|
||||
def _generate_flat_file(self, models, output_path: Path) -> None:
|
||||
"""Generate flat models to a single file (original behavior)."""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Detect input type and generate accordingly
|
||||
if hasattr(models, "get_shared_component"):
|
||||
# ConfigLoader (soleprint config)
|
||||
content = self._generate_from_config(models)
|
||||
elif hasattr(models, "models"):
|
||||
# SchemaLoader
|
||||
content = self._generate_from_definitions(
|
||||
models.models, getattr(models, "enums", [])
|
||||
)
|
||||
elif isinstance(models, tuple):
|
||||
# (models, enums) tuple from extractor
|
||||
content = self._generate_from_definitions(models[0], models[1])
|
||||
elif isinstance(models, list):
|
||||
# List of dataclasses (MPR style)
|
||||
content = self._generate_from_dataclasses(models)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(models)}")
|
||||
|
||||
output_path.write_text(content)
|
||||
|
||||
def _generate_crud_directory(self, models, output_dir: Path) -> None:
|
||||
"""Generate CRUD variant files in a directory."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if hasattr(models, "models"):
|
||||
model_defs = models.models
|
||||
enum_defs = getattr(models, "enums", [])
|
||||
elif isinstance(models, tuple):
|
||||
model_defs = models[0]
|
||||
enum_defs = models[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type for CRUD mode: {type(models)}")
|
||||
|
||||
# base.py
|
||||
base_content = "\n".join([
|
||||
'"""Pydantic Base Schema - GENERATED FILE"""',
|
||||
"",
|
||||
"from pydantic import BaseModel, ConfigDict",
|
||||
"",
|
||||
"",
|
||||
"class BaseSchema(BaseModel):",
|
||||
' """Base schema with ORM mode."""',
|
||||
" model_config = ConfigDict(from_attributes=True)",
|
||||
"",
|
||||
])
|
||||
(output_dir / "base.py").write_text(base_content)
|
||||
|
||||
# Per-model files
|
||||
imports = ["from .base import BaseSchema"]
|
||||
all_exports = ['"BaseSchema"']
|
||||
|
||||
for model_def in model_defs:
|
||||
mapped = self.map_name(model_def.name)
|
||||
module_name = mapped.lower()
|
||||
|
||||
lines = [
|
||||
f'"""{model_def.name} Schemas - GENERATED FILE"""',
|
||||
"",
|
||||
"from datetime import datetime",
|
||||
"from enum import Enum",
|
||||
"from typing import Any, Dict, List, Optional",
|
||||
"from uuid import UUID",
|
||||
"",
|
||||
"from .base import BaseSchema",
|
||||
"",
|
||||
]
|
||||
|
||||
# Inline enums used by this model
|
||||
model_enums = self._collect_model_enums(model_def, enum_defs)
|
||||
for enum_def in model_enums:
|
||||
lines.append("")
|
||||
lines.extend(self._generate_enum(enum_def))
|
||||
lines.append("")
|
||||
|
||||
# CRUD variants
|
||||
for suffix in ["Create", "Update", "Response"]:
|
||||
lines.append("")
|
||||
lines.extend(self._generate_crud_model(model_def, mapped, suffix))
|
||||
|
||||
lines.append("")
|
||||
content = "\n".join(lines)
|
||||
(output_dir / f"{module_name}.py").write_text(content)
|
||||
|
||||
# Track imports
|
||||
imports.append(
|
||||
f"from .{module_name} import {mapped}Create, {mapped}Update, {mapped}Response"
|
||||
)
|
||||
all_exports.extend([
|
||||
f'"{mapped}Create"', f'"{mapped}Update"', f'"{mapped}Response"'
|
||||
])
|
||||
|
||||
for enum_def in model_enums:
|
||||
imports.append(f"from .{module_name} import {enum_def.name}")
|
||||
all_exports.append(f'"{enum_def.name}"')
|
||||
|
||||
# __init__.py
|
||||
init_content = "\n".join([
|
||||
'"""API Schemas - GENERATED FILE"""',
|
||||
"",
|
||||
*imports,
|
||||
"",
|
||||
f"__all__ = [{', '.join(all_exports)}]",
|
||||
"",
|
||||
])
|
||||
(output_dir / "__init__.py").write_text(init_content)
|
||||
|
||||
def _collect_model_enums(
|
||||
self, model_def: ModelDefinition, enum_defs: List[EnumDefinition]
|
||||
) -> List[EnumDefinition]:
|
||||
"""Find enums referenced by a model's fields."""
|
||||
enum_names = set()
|
||||
for field in model_def.fields:
|
||||
base, _ = unwrap_optional(field.type_hint)
|
||||
if isinstance(base, type) and issubclass(base, Enum):
|
||||
enum_names.add(base.__name__)
|
||||
return [e for e in enum_defs if e.name in enum_names]
|
||||
|
||||
def _generate_crud_model(
|
||||
self, model_def: ModelDefinition, mapped_name: str, suffix: str
|
||||
) -> List[str]:
|
||||
"""Generate a single CRUD variant (Create/Update/Response)."""
|
||||
class_name = f"{mapped_name}{suffix}"
|
||||
skip = SKIP_FIELDS.get(suffix, set())
|
||||
|
||||
lines = [
|
||||
f"class {class_name}(BaseSchema):",
|
||||
f' """{class_name} schema."""',
|
||||
]
|
||||
|
||||
has_fields = False
|
||||
for field in model_def.fields:
|
||||
if field.name.startswith("_") or field.name in skip:
|
||||
continue
|
||||
|
||||
has_fields = True
|
||||
py_type = self._resolve_type(field.type_hint, field.optional)
|
||||
|
||||
# Update variant: all fields optional
|
||||
if suffix == "Update" and "Optional" not in py_type:
|
||||
py_type = f"Optional[{py_type}]"
|
||||
|
||||
default = self._format_default(field.default, "Optional" in py_type)
|
||||
lines.append(f" {field.name}: {py_type}{default}")
|
||||
|
||||
if not has_fields:
|
||||
lines.append(" pass")
|
||||
|
||||
return lines
|
||||
|
||||
# =========================================================================
|
||||
# Flat file generation (original behavior)
|
||||
# =========================================================================
|
||||
|
||||
def _generate_from_definitions(
|
||||
self, models: List[ModelDefinition], enums: List[EnumDefinition]
|
||||
) -> str:
|
||||
"""Generate from ModelDefinition objects (schema/extract mode)."""
|
||||
lines = self._generate_header()
|
||||
|
||||
# Generate enums
|
||||
for enum_def in enums:
|
||||
lines.extend(self._generate_enum(enum_def))
|
||||
lines.append("")
|
||||
|
||||
# Generate models
|
||||
for model_def in models:
|
||||
lines.extend(self._generate_model_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
|
||||
"""Generate from Python dataclasses (MPR style)."""
|
||||
lines = self._generate_header()
|
||||
|
||||
# Collect and generate enums first
|
||||
enums_generated = set()
|
||||
for cls in dataclasses:
|
||||
hints = get_type_hints(cls)
|
||||
@@ -77,16 +223,12 @@ class PydanticGenerator(BaseGenerator):
|
||||
lines.extend(self._generate_enum_from_python(base))
|
||||
lines.append("")
|
||||
enums_generated.add(base.__name__)
|
||||
|
||||
# Generate models
|
||||
for cls in dataclasses:
|
||||
lines.extend(self._generate_model_from_dataclass(cls))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_header(self) -> List[str]:
|
||||
"""Generate file header."""
|
||||
return [
|
||||
'"""',
|
||||
"Pydantic Models - GENERATED FILE",
|
||||
@@ -104,27 +246,23 @@ class PydanticGenerator(BaseGenerator):
|
||||
]
|
||||
|
||||
def _generate_enum(self, enum_def: EnumDefinition) -> List[str]:
|
||||
"""Generate Pydantic enum from EnumDefinition."""
|
||||
lines = [f"class {enum_def.name}(str, Enum):"]
|
||||
for name, value in enum_def.values:
|
||||
lines.append(f' {name} = "{value}"')
|
||||
return lines
|
||||
|
||||
def _generate_enum_from_python(self, enum_cls: type) -> List[str]:
|
||||
"""Generate Pydantic enum from Python Enum."""
|
||||
lines = [f"class {enum_cls.__name__}(str, Enum):"]
|
||||
for member in enum_cls:
|
||||
lines.append(f' {member.name} = "{member.value}"')
|
||||
return lines
|
||||
|
||||
def _generate_model_from_definition(self, model_def: ModelDefinition) -> List[str]:
|
||||
"""Generate Pydantic model from ModelDefinition."""
|
||||
docstring = model_def.docstring or model_def.name
|
||||
lines = [
|
||||
f"class {model_def.name}(BaseModel):",
|
||||
f' """{docstring.strip().split(chr(10))[0]}"""',
|
||||
]
|
||||
|
||||
if not model_def.fields:
|
||||
lines.append(" pass")
|
||||
else:
|
||||
@@ -132,46 +270,34 @@ class PydanticGenerator(BaseGenerator):
|
||||
py_type = self._resolve_type(field.type_hint, field.optional)
|
||||
default = self._format_default(field.default, field.optional)
|
||||
lines.append(f" {field.name}: {py_type}{default}")
|
||||
|
||||
return lines
|
||||
|
||||
def _generate_model_from_dataclass(self, cls: type) -> List[str]:
|
||||
"""Generate Pydantic model from a dataclass."""
|
||||
import dataclasses as dc
|
||||
|
||||
docstring = cls.__doc__ or cls.__name__
|
||||
lines = [
|
||||
f"class {cls.__name__}(BaseModel):",
|
||||
f' """{docstring.strip().split(chr(10))[0]}"""',
|
||||
]
|
||||
|
||||
hints = get_type_hints(cls)
|
||||
fields = {f.name: f for f in dc.fields(cls)}
|
||||
|
||||
for name, type_hint in hints.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
field = fields.get(name)
|
||||
default_val = dc.MISSING
|
||||
if field:
|
||||
if field.default is not dc.MISSING:
|
||||
default_val = field.default
|
||||
|
||||
py_type = self._resolve_type(type_hint, False)
|
||||
default = self._format_default(default_val, "Optional" in py_type)
|
||||
lines.append(f" {name}: {py_type}{default}")
|
||||
|
||||
return lines
|
||||
|
||||
def _resolve_type(self, type_hint: Any, optional: bool) -> str:
|
||||
"""Resolve Python type to Pydantic type string."""
|
||||
base, is_optional = unwrap_optional(type_hint)
|
||||
optional = optional or is_optional
|
||||
origin = get_origin_name(base)
|
||||
type_name = get_type_name(base)
|
||||
|
||||
# Look up resolver
|
||||
resolver = (
|
||||
PYDANTIC_RESOLVERS.get(origin)
|
||||
or PYDANTIC_RESOLVERS.get(type_name)
|
||||
@@ -182,14 +308,10 @@ class PydanticGenerator(BaseGenerator):
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
result = resolver(base) if resolver else "str"
|
||||
return f"Optional[{result}]" if optional else result
|
||||
|
||||
def _format_default(self, default: Any, optional: bool) -> str:
|
||||
"""Format default value for field."""
|
||||
import dataclasses as dc
|
||||
|
||||
if optional:
|
||||
return " = None"
|
||||
if default is dc.MISSING or default is None:
|
||||
@@ -204,7 +326,6 @@ class PydanticGenerator(BaseGenerator):
|
||||
|
||||
def _generate_from_config(self, config) -> str:
|
||||
"""Generate from ConfigLoader (soleprint config.json mode)."""
|
||||
# Get component names from config
|
||||
config_comp = config.get_shared_component("config")
|
||||
data_comp = config.get_shared_component("data")
|
||||
|
||||
|
||||
@@ -26,11 +26,10 @@ class TypeScriptGenerator(BaseGenerator):
|
||||
|
||||
# Handle different input types
|
||||
if hasattr(models, "models"):
|
||||
# SchemaLoader
|
||||
# SchemaLoader — include api_models if present
|
||||
all_models = models.models + getattr(models, "api_models", [])
|
||||
content = self._generate_from_definitions(
|
||||
models.models,
|
||||
getattr(models, "enums", []),
|
||||
api_models=getattr(models, "api_models", []),
|
||||
all_models, getattr(models, "enums", [])
|
||||
)
|
||||
elif isinstance(models, tuple):
|
||||
# (models, enums) tuple
|
||||
@@ -44,10 +43,7 @@ class TypeScriptGenerator(BaseGenerator):
|
||||
output_path.write_text(content)
|
||||
|
||||
def _generate_from_definitions(
|
||||
self,
|
||||
models: List[ModelDefinition],
|
||||
enums: List[EnumDefinition],
|
||||
api_models: List[ModelDefinition] = None,
|
||||
self, models: List[ModelDefinition], enums: List[EnumDefinition]
|
||||
) -> str:
|
||||
"""Generate from ModelDefinition objects."""
|
||||
lines = self._generate_header()
|
||||
@@ -63,14 +59,6 @@ class TypeScriptGenerator(BaseGenerator):
|
||||
lines.extend(self._generate_interface_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
# Generate API request/response interfaces
|
||||
if api_models:
|
||||
lines.append("// API request/response types")
|
||||
lines.append("")
|
||||
for model_def in api_models:
|
||||
lines.extend(self._generate_interface_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user