""" SQLModel Generator Generates SQLModel table classes from model definitions. Extends the Pydantic generator — SQLModel classes *are* Pydantic models with table=True and SQLAlchemy column config for JSON fields. """ import dataclasses as dc import re from enum import Enum from typing import Any, List, get_type_hints from ..helpers import get_origin_name, get_type_name, unwrap_optional from .pydantic import PydanticGenerator # --------------------------------------------------------------------------- # Field resolvers — each returns a Field() string or None to fall through # --------------------------------------------------------------------------- def _resolve_special(name, _base, _origin, _optional, _default): """id, created_at, updated_at get fixed Field() definitions.""" specials = { "id": "Field(default_factory=uuid4, primary_key=True)", "created_at": "Field(default_factory=datetime.utcnow)", "updated_at": "Field(default_factory=datetime.utcnow)", } return specials.get(name) def _resolve_json(name, _base, origin, _optional, _default): """Dict and List fields → sa_column=Column(JSON).""" mapping = { "dict": ("dict", "{}"), "list": ("list", "[]"), } entry = mapping.get(origin) if not entry: return None factory, server_default = entry return ( f"Field(default_factory={factory}, " f"sa_column=Column(JSON, nullable=False, server_default='{server_default}'))" ) def _resolve_indexed(name, _base, _origin, optional, _default): """Known indexed fields.""" indexed = {"source_asset_id", "parent_job_id", "job_id", "canonical_name"} if name not in indexed: return None if optional: return "Field(default=None, index=True)" return "Field(index=True)" def _resolve_optional(_name, _base, _origin, optional, _default): """Optional fields default to None.""" if optional: return "None" return None def _resolve_default(_name, _base, _origin, _optional, default): """Fields with explicit defaults. Enum before str (str enums are both).""" if default is dc.MISSING or default is None: return None if isinstance(default, Enum): return f'"{default.value}"' if isinstance(default, bool): return str(default) if isinstance(default, (int, float)): return str(default) if isinstance(default, str): return f'"{default}"' return None # Resolver chain — first non-None result wins _FIELD_RESOLVERS = [ _resolve_special, _resolve_json, _resolve_indexed, _resolve_optional, _resolve_default, ] def _resolve_field(name, type_hint, default): """Run the resolver chain for a field. Returns ' = ...' string.""" base, is_optional = unwrap_optional(type_hint) origin = get_origin_name(base) for resolver in _FIELD_RESOLVERS: result = resolver(name, base, origin, is_optional, default) if result is not None: return f" = {result}" return "" def _to_snake_plural(name): """CamelCase → snake_case_plural for table names.""" s = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() if s.endswith("y") and not s.endswith("ey"): return s[:-1] + "ies" if s.endswith("s"): return s + "es" return s + "s" _HEADER = [ '"""', "SQLModel Table Models - GENERATED FILE", "", "Do not edit directly. Regenerate using modelgen.", '"""', "", "from datetime import datetime", "from enum import Enum", "from typing import Any, Dict, List, Optional", "from uuid import UUID, uuid4", "", "from sqlmodel import SQLModel, Field, Column", "from sqlalchemy import JSON", "", ] class SQLModelGenerator(PydanticGenerator): """Generates SQLModel table classes.""" def _generate_header(self) -> List[str]: return list(_HEADER) def _generate_model_from_dataclass(self, cls: type) -> List[str]: return _build_table( cls.__name__, cls.__doc__ or cls.__name__, get_type_hints(cls), {f.name: f for f in dc.fields(cls)}, self._resolve_type, ) def _generate_model_from_definition(self, model_def) -> List[str]: hints = {f.name: f.type_hint for f in model_def.fields} defaults = {f.name: f.default for f in model_def.fields} class FakeField: def __init__(self, default): self.default = default fields = {name: FakeField(defaults.get(name, dc.MISSING)) for name in hints} return _build_table( model_def.name, model_def.docstring or model_def.name, hints, fields, self._resolve_type, ) def _build_table(name, docstring, hints, fields, resolve_type_fn): """Build a SQLModel table class from field data.""" table_name = _to_snake_plural(name) lines = [ f"class {name}(SQLModel, table=True):", f' """{docstring.strip().split(chr(10))[0]}"""', f' __tablename__ = "{table_name}"', "", ] for field_name, type_hint in hints.items(): if field_name.startswith("_"): continue field = fields.get(field_name) default_val = dc.MISSING if field and field.default is not dc.MISSING: default_val = field.default py_type = resolve_type_fn(type_hint, False) field_extra = _resolve_field(field_name, type_hint, default_val) lines.append(f" {field_name}: {py_type}{field_extra}") return lines