182 lines
5.4 KiB
Python
182 lines
5.4 KiB
Python
"""
|
|
SQLModel Generator
|
|
|
|
Generates SQLModel table classes from model definitions.
|
|
Extends the Pydantic generator — SQLModel classes *are* Pydantic models
|
|
with table=True and SQLAlchemy column config for JSON fields.
|
|
"""
|
|
|
|
import dataclasses as dc
|
|
import re
|
|
from enum import Enum
|
|
from typing import Any, List, get_type_hints
|
|
|
|
from ..helpers import get_origin_name, get_type_name, unwrap_optional
|
|
from .pydantic import PydanticGenerator
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Field resolvers — each returns the right-hand side of a field assignment
# (e.g. a Field(...) call) or None to fall through to the next resolver
# ---------------------------------------------------------------------------
|
|
|
|
def _resolve_special(name, _base, _origin, _optional, _default):
    """Return a fixed Field() definition for id/created_at/updated_at.

    ``id`` becomes a UUID primary key; the two timestamp columns default to
    ``datetime.utcnow``. Any other name yields None.
    """
    if name == "id":
        return "Field(default_factory=uuid4, primary_key=True)"
    if name in ("created_at", "updated_at"):
        return "Field(default_factory=datetime.utcnow)"
    return None
|
|
|
|
|
|
# Origin name -> (default factory emitted in Field(), server-side default).
_JSON_CONTAINERS = {
    "dict": ("dict", "{}"),
    "list": ("list", "[]"),
}


def _resolve_json(name, _base, origin, _optional, _default):
    """Map dict- and list-typed fields onto a non-nullable JSON column.

    Returns None when ``origin`` is neither ``"dict"`` nor ``"list"`` so the
    next resolver in the chain can run.
    """
    try:
        factory_name, empty_literal = _JSON_CONTAINERS[origin]
    except KeyError:
        return None
    return (
        "Field(default_factory=" + factory_name + ", "
        "sa_column=Column(JSON, nullable=False, server_default='"
        + empty_literal
        + "'))"
    )
|
|
|
|
|
|
# Field names that always get a database index.
_INDEXED_FIELDS = frozenset(
    {"source_asset_id", "parent_job_id", "job_id", "canonical_name"}
)


def _resolve_indexed(name, _base, _origin, optional, _default):
    """Return an indexed Field() for known lookup columns, else None.

    Optional indexed fields additionally get ``default=None``.
    """
    if name in _INDEXED_FIELDS:
        return "Field(default=None, index=True)" if optional else "Field(index=True)"
    return None
|
|
|
|
|
|
def _resolve_optional(_name, _base, _origin, optional, _default):
    """Give Optional fields a ``None`` default; other fields fall through."""
    return "None" if optional else None
|
|
|
|
|
|
def _resolve_default(_name, _base, _origin, _optional, default):
    """Render an explicit scalar default as Python source, or return None.

    Handles Enum members (emitted as their quoted value), bool, int, float
    and str. ``dataclasses.MISSING``, ``None`` and any other type return
    None so later resolvers (or no default at all) apply.

    Enum is checked before str because str-valued enums are instances of
    both.
    """
    import json
    import math

    if default is dc.MISSING or default is None:
        return None
    if isinstance(default, Enum):
        # Emitted as a quoted string of the member's value, escaped like any
        # other string default.
        return json.dumps(str(default.value), ensure_ascii=False)
    if isinstance(default, float) and not math.isfinite(default):
        # str() gives inf / -inf / nan, which are not valid Python literals.
        return f'float("{default}")'
    if isinstance(default, (bool, int, float)):
        # bool is checked together with int; str() yields True/False or a
        # numeric literal either way.
        return str(default)
    if isinstance(default, str):
        # json.dumps produces a double-quoted literal with quotes,
        # backslashes and control characters escaped; that is also valid
        # Python string syntax, unlike naive '"' + value + '"' wrapping.
        return json.dumps(default, ensure_ascii=False)
    return None
|
|
|
|
|
|
# Resolver chain — first non-None result wins.
#
# _resolve_default runs before _resolve_optional so that an Optional field
# with an explicit, renderable default (e.g. Optional[str] = "pending") keeps
# that default instead of being forced to None. Optional fields with no
# default, a None default, or a default _resolve_default cannot render still
# fall through to _resolve_optional and get "None".
_FIELD_RESOLVERS = [
    _resolve_special,
    _resolve_json,
    _resolve_indexed,
    _resolve_default,
    _resolve_optional,
]
|
|
|
|
|
|
def _resolve_field(name, type_hint, default):
    """Return the ``' = ...'`` suffix for a field, or ``""`` when none applies.

    The hint is unwrapped from Optional and its origin name derived, then each
    resolver in ``_FIELD_RESOLVERS`` is consulted in order until one returns
    something other than None.
    """
    base, optional = unwrap_optional(type_hint)
    origin = get_origin_name(base)

    # Lazy generator: resolvers after the first hit are never called.
    answers = (
        resolver(name, base, origin, optional, default)
        for resolver in _FIELD_RESOLVERS
    )
    chosen = next((answer for answer in answers if answer is not None), None)
    if chosen is None:
        return ""
    return f" = {chosen}"
|
|
|
|
|
|
def _to_snake(name):
    """Convert a CamelCase class name to a snake_case table name.

    An underscore is inserted only where an ASCII lowercase letter is directly
    followed by an ASCII uppercase letter, so acronym runs and digits are not
    split (``HTTPRequest`` -> ``httprequest``); the result is lowercased.
    """
    pieces = []
    previous = ""
    for char in name:
        if "A" <= char <= "Z" and "a" <= previous <= "z":
            pieces.append("_")
        pieces.append(char)
        previous = char
    return "".join(pieces).lower()
|
|
|
|
|
|
# Lines written at the top of every generated SQLModel module. Kept as one
# literal block (split into a list of lines) so it reads like the emitted
# file; the trailing newline yields a final empty line.
_HEADER = '''\
"""
SQLModel Table Models - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4

from sqlmodel import SQLModel, Field, Column
from sqlalchemy import JSON
'''.split("\n")
|
|
|
|
|
|
class SQLModelGenerator(PydanticGenerator):
    """Generates SQLModel table classes.

    Overrides the header and per-model emission of the parent generator to
    produce ``table=True`` classes; field annotations are rendered with the
    inherited ``self._resolve_type``.
    """

    def _generate_header(self) -> List[str]:
        """Return a fresh copy of the generated-file header lines."""
        # Copy so callers may extend the result without mutating _HEADER.
        return list(_HEADER)

    def _generate_model_from_dataclass(self, cls: type) -> List[str]:
        """Emit a SQLModel table class for a dataclass model.

        Only real dataclass fields become columns. ``get_type_hints`` also
        returns ``ClassVar`` and ``InitVar`` annotations, which
        ``dataclasses.fields`` excludes; those must not be emitted as table
        columns.
        """
        fields = {f.name: f for f in dc.fields(cls)}
        hints = {
            field_name: hint
            for field_name, hint in get_type_hints(cls).items()
            if field_name in fields
        }
        return _build_table(
            cls.__name__,
            cls.__doc__ or cls.__name__,
            hints,
            fields,
            self._resolve_type,
        )

    def _generate_model_from_definition(self, model_def) -> List[str]:
        """Emit a SQLModel table class for a non-dataclass model definition.

        ``model_def`` must expose ``name``, ``docstring`` and ``fields``; each
        field must expose ``name``, ``type_hint`` and ``default``.
        """
        hints = {f.name: f.type_hint for f in model_def.fields}
        defaults = {f.name: f.default for f in model_def.fields}

        class FakeField:
            # Minimal stand-in for dataclasses.Field: _build_table only reads
            # the ``default`` attribute.
            def __init__(self, default):
                self.default = default

        fields = {name: FakeField(defaults.get(name, dc.MISSING)) for name in hints}
        return _build_table(
            model_def.name,
            model_def.docstring or model_def.name,
            hints,
            fields,
            self._resolve_type,
        )
|
|
|
|
|
|
def _docstring_summary(docstring):
    """Return the first line of *docstring*, escaped for a triple-quoted literal.

    Backslashes and double quotes are escaped so the summary can sit between
    triple double quotes in generated code without ending the literal early
    (e.g. a summary that ends with a quote character) or being read as an
    escape sequence.
    """
    lines = docstring.strip().splitlines()
    summary = lines[0] if lines else ""
    return summary.replace("\\", "\\\\").replace('"', '\\"')


def _build_table(name, docstring, hints, fields, resolve_type_fn):
    """Build the source lines of a SQLModel table class.

    Args:
        name: Class name to emit; ``__tablename__`` is its snake_case form.
        docstring: Model docstring; only its first line is emitted.
        hints: Mapping of field name -> type hint, iterated in emission order.
            Names starting with ``_`` are skipped.
        fields: Mapping of field name -> object with a ``default`` attribute
            (``dataclasses.MISSING`` meaning "no default"); names missing
            from the mapping are treated as having no default.
        resolve_type_fn: Called as ``resolve_type_fn(type_hint, False)``;
            its result is formatted in as the field annotation.

    Returns:
        List of source lines (without trailing newlines).
    """
    lines = [
        f"class {name}(SQLModel, table=True):",
        f'    """{_docstring_summary(docstring)}"""',
        f'    __tablename__ = "{_to_snake(name)}"',
        "",
    ]

    for field_name, type_hint in hints.items():
        if field_name.startswith("_"):
            continue

        field = fields.get(field_name)
        default_val = dc.MISSING
        if field and field.default is not dc.MISSING:
            default_val = field.default

        py_type = resolve_type_fn(type_hint, False)
        field_extra = _resolve_field(field_name, type_hint, default_val)
        lines.append(f"    {field_name}: {py_type}{field_extra}")

    return lines
|