move to postgresql
This commit is contained in:
@@ -17,6 +17,7 @@ from .django import DjangoGenerator
|
||||
from .prisma import PrismaGenerator
|
||||
from .protobuf import ProtobufGenerator
|
||||
from .pydantic import PydanticGenerator
|
||||
from .sqlmodel import SQLModelGenerator
|
||||
from .strawberry import StrawberryGenerator
|
||||
from .typescript import TypeScriptGenerator
|
||||
|
||||
@@ -24,6 +25,7 @@ from .typescript import TypeScriptGenerator
|
||||
GENERATORS: Dict[str, Type[BaseGenerator]] = {
|
||||
"pydantic": PydanticGenerator,
|
||||
"django": DjangoGenerator,
|
||||
"sqlmodel": SQLModelGenerator,
|
||||
"typescript": TypeScriptGenerator,
|
||||
"ts": TypeScriptGenerator, # Alias
|
||||
"protobuf": ProtobufGenerator,
|
||||
|
||||
186
modelgen/generator/sqlmodel.py
Normal file
186
modelgen/generator/sqlmodel.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
SQLModel Generator
|
||||
|
||||
Generates SQLModel table classes from model definitions.
|
||||
Extends the Pydantic generator — SQLModel classes *are* Pydantic models
|
||||
with table=True and SQLAlchemy column config for JSON fields.
|
||||
"""
|
||||
|
||||
import dataclasses as dc
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, List, get_type_hints
|
||||
|
||||
from ..helpers import get_origin_name, get_type_name, unwrap_optional
|
||||
from .pydantic import PydanticGenerator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Field resolvers — each returns a Field() string or None to fall through
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_special(name, _base, _origin, _optional, _default):
|
||||
"""id, created_at, updated_at get fixed Field() definitions."""
|
||||
specials = {
|
||||
"id": "Field(default_factory=uuid4, primary_key=True)",
|
||||
"created_at": "Field(default_factory=datetime.utcnow)",
|
||||
"updated_at": "Field(default_factory=datetime.utcnow)",
|
||||
}
|
||||
return specials.get(name)
|
||||
|
||||
|
||||
def _resolve_json(name, _base, origin, _optional, _default):
|
||||
"""Dict and List fields → sa_column=Column(JSON)."""
|
||||
mapping = {
|
||||
"dict": ("dict", "{}"),
|
||||
"list": ("list", "[]"),
|
||||
}
|
||||
entry = mapping.get(origin)
|
||||
if not entry:
|
||||
return None
|
||||
factory, server_default = entry
|
||||
return (
|
||||
f"Field(default_factory={factory}, "
|
||||
f"sa_column=Column(JSON, nullable=False, server_default='{server_default}'))"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_indexed(name, _base, _origin, optional, _default):
|
||||
"""Known indexed fields."""
|
||||
indexed = {"source_asset_id", "parent_job_id", "job_id", "canonical_name"}
|
||||
if name not in indexed:
|
||||
return None
|
||||
if optional:
|
||||
return "Field(default=None, index=True)"
|
||||
return "Field(index=True)"
|
||||
|
||||
|
||||
def _resolve_optional(_name, _base, _origin, optional, _default):
|
||||
"""Optional fields default to None."""
|
||||
if optional:
|
||||
return "None"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_default(_name, _base, _origin, _optional, default):
|
||||
"""Fields with explicit defaults. Enum before str (str enums are both)."""
|
||||
if default is dc.MISSING or default is None:
|
||||
return None
|
||||
if isinstance(default, Enum):
|
||||
return f'"{default.value}"'
|
||||
if isinstance(default, bool):
|
||||
return str(default)
|
||||
if isinstance(default, (int, float)):
|
||||
return str(default)
|
||||
if isinstance(default, str):
|
||||
return f'"{default}"'
|
||||
return None
|
||||
|
||||
|
||||
# Resolver chain — first non-None result wins
|
||||
_FIELD_RESOLVERS = [
|
||||
_resolve_special,
|
||||
_resolve_json,
|
||||
_resolve_indexed,
|
||||
_resolve_optional,
|
||||
_resolve_default,
|
||||
]
|
||||
|
||||
|
||||
def _resolve_field(name, type_hint, default):
|
||||
"""Run the resolver chain for a field. Returns ' = ...' string."""
|
||||
base, is_optional = unwrap_optional(type_hint)
|
||||
origin = get_origin_name(base)
|
||||
|
||||
for resolver in _FIELD_RESOLVERS:
|
||||
result = resolver(name, base, origin, is_optional, default)
|
||||
if result is not None:
|
||||
return f" = {result}"
|
||||
return ""
|
||||
|
||||
|
||||
def _to_snake_plural(name):
|
||||
"""CamelCase → snake_case_plural for table names."""
|
||||
s = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower()
|
||||
if s.endswith("y") and not s.endswith("ey"):
|
||||
return s[:-1] + "ies"
|
||||
if s.endswith("s"):
|
||||
return s + "es"
|
||||
return s + "s"
|
||||
|
||||
|
||||
_HEADER = [
|
||||
'"""',
|
||||
"SQLModel Table Models - GENERATED FILE",
|
||||
"",
|
||||
"Do not edit directly. Regenerate using modelgen.",
|
||||
'"""',
|
||||
"",
|
||||
"from datetime import datetime",
|
||||
"from enum import Enum",
|
||||
"from typing import Any, Dict, List, Optional",
|
||||
"from uuid import UUID, uuid4",
|
||||
"",
|
||||
"from sqlmodel import SQLModel, Field, Column",
|
||||
"from sqlalchemy import JSON",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
class SQLModelGenerator(PydanticGenerator):
|
||||
"""Generates SQLModel table classes."""
|
||||
|
||||
def _generate_header(self) -> List[str]:
|
||||
return list(_HEADER)
|
||||
|
||||
def _generate_model_from_dataclass(self, cls: type) -> List[str]:
|
||||
return _build_table(
|
||||
cls.__name__,
|
||||
cls.__doc__ or cls.__name__,
|
||||
get_type_hints(cls),
|
||||
{f.name: f for f in dc.fields(cls)},
|
||||
self._resolve_type,
|
||||
)
|
||||
|
||||
def _generate_model_from_definition(self, model_def) -> List[str]:
|
||||
hints = {f.name: f.type_hint for f in model_def.fields}
|
||||
defaults = {f.name: f.default for f in model_def.fields}
|
||||
|
||||
class FakeField:
|
||||
def __init__(self, default):
|
||||
self.default = default
|
||||
|
||||
fields = {name: FakeField(defaults.get(name, dc.MISSING)) for name in hints}
|
||||
return _build_table(
|
||||
model_def.name,
|
||||
model_def.docstring or model_def.name,
|
||||
hints,
|
||||
fields,
|
||||
self._resolve_type,
|
||||
)
|
||||
|
||||
|
||||
def _build_table(name, docstring, hints, fields, resolve_type_fn):
|
||||
"""Build a SQLModel table class from field data."""
|
||||
table_name = _to_snake_plural(name)
|
||||
lines = [
|
||||
f"class {name}(SQLModel, table=True):",
|
||||
f' """{docstring.strip().split(chr(10))[0]}"""',
|
||||
f' __tablename__ = "{table_name}"',
|
||||
"",
|
||||
]
|
||||
|
||||
for field_name, type_hint in hints.items():
|
||||
if field_name.startswith("_"):
|
||||
continue
|
||||
|
||||
field = fields.get(field_name)
|
||||
default_val = dc.MISSING
|
||||
if field and field.default is not dc.MISSING:
|
||||
default_val = field.default
|
||||
|
||||
py_type = resolve_type_fn(type_hint, False)
|
||||
field_extra = _resolve_field(field_name, type_hint, default_val)
|
||||
lines.append(f" {field_name}: {py_type}{field_extra}")
|
||||
|
||||
return lines
|
||||
Reference in New Issue
Block a user