remove duplicated code
This commit is contained in:
173
modelgen/generator/prisma.py
Normal file
173
modelgen/generator/prisma.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Prisma Generator
|
||||
|
||||
Generates Prisma schema from model definitions.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, List, get_type_hints
|
||||
|
||||
from ..helpers import get_origin_name, get_type_name, unwrap_optional
|
||||
from ..loader.schema import EnumDefinition, ModelDefinition
|
||||
from ..types import PRISMA_SPECIAL, PRISMA_TYPES
|
||||
from .base import BaseGenerator
|
||||
|
||||
|
||||
class PrismaGenerator(BaseGenerator):
|
||||
"""Generates Prisma schema files."""
|
||||
|
||||
def file_extension(self) -> str:
|
||||
return ".prisma"
|
||||
|
||||
def generate(self, models, output_path: Path) -> None:
|
||||
"""Generate Prisma schema to output_path."""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Handle different input types
|
||||
if hasattr(models, "models"):
|
||||
# SchemaLoader
|
||||
content = self._generate_from_definitions(
|
||||
models.models, getattr(models, "enums", [])
|
||||
)
|
||||
elif isinstance(models, tuple):
|
||||
# (models, enums) tuple
|
||||
content = self._generate_from_definitions(models[0], models[1])
|
||||
elif isinstance(models, list):
|
||||
# List of dataclasses (MPR style)
|
||||
content = self._generate_from_dataclasses(models)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(models)}")
|
||||
|
||||
output_path.write_text(content)
|
||||
|
||||
def _generate_from_definitions(
|
||||
self, models: List[ModelDefinition], enums: List[EnumDefinition]
|
||||
) -> str:
|
||||
"""Generate from ModelDefinition objects."""
|
||||
lines = self._generate_header()
|
||||
|
||||
# Generate enums
|
||||
for enum_def in enums:
|
||||
lines.extend(self._generate_enum(enum_def))
|
||||
lines.append("")
|
||||
|
||||
# Generate models
|
||||
for model_def in models:
|
||||
lines.extend(self._generate_model_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
|
||||
"""Generate from Python dataclasses (MPR style)."""
|
||||
lines = self._generate_header()
|
||||
|
||||
# Collect and generate enums first
|
||||
enums_generated = set()
|
||||
for cls in dataclasses:
|
||||
hints = get_type_hints(cls)
|
||||
for type_hint in hints.values():
|
||||
base, _ = unwrap_optional(type_hint)
|
||||
if isinstance(base, type) and issubclass(base, Enum):
|
||||
if base.__name__ not in enums_generated:
|
||||
lines.extend(self._generate_enum_from_python(base))
|
||||
lines.append("")
|
||||
enums_generated.add(base.__name__)
|
||||
|
||||
# Generate models
|
||||
for cls in dataclasses:
|
||||
lines.extend(self._generate_model_from_dataclass(cls))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_header(self) -> List[str]:
|
||||
"""Generate file header with datasource and generator."""
|
||||
return [
|
||||
"// Prisma Schema - GENERATED FILE",
|
||||
"//",
|
||||
"// Do not edit directly. Regenerate using modelgen.",
|
||||
"",
|
||||
"generator client {",
|
||||
' provider = "prisma-client-py"',
|
||||
"}",
|
||||
"",
|
||||
"datasource db {",
|
||||
' provider = "postgresql"',
|
||||
' url = env("DATABASE_URL")',
|
||||
"}",
|
||||
"",
|
||||
]
|
||||
|
||||
def _generate_enum(self, enum_def: EnumDefinition) -> List[str]:
|
||||
"""Generate Prisma enum from EnumDefinition."""
|
||||
lines = [f"enum {enum_def.name} {{"]
|
||||
for name, _ in enum_def.values:
|
||||
lines.append(f" {name}")
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _generate_enum_from_python(self, enum_cls: type) -> List[str]:
|
||||
"""Generate Prisma enum from Python Enum."""
|
||||
lines = [f"enum {enum_cls.__name__} {{"]
|
||||
for member in enum_cls:
|
||||
lines.append(f" {member.name}")
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _generate_model_from_definition(self, model_def: ModelDefinition) -> List[str]:
|
||||
"""Generate Prisma model from ModelDefinition."""
|
||||
lines = [f"model {model_def.name} {{"]
|
||||
|
||||
for field in model_def.fields:
|
||||
prisma_type = self._resolve_type(
|
||||
field.name, field.type_hint, field.optional
|
||||
)
|
||||
lines.append(f" {field.name} {prisma_type}")
|
||||
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _generate_model_from_dataclass(self, cls: type) -> List[str]:
|
||||
"""Generate Prisma model from a dataclass."""
|
||||
lines = [f"model {cls.__name__} {{"]
|
||||
|
||||
for name, type_hint in get_type_hints(cls).items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
prisma_type = self._resolve_type(name, type_hint, False)
|
||||
lines.append(f" {name} {prisma_type}")
|
||||
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _resolve_type(self, name: str, type_hint: Any, optional: bool) -> str:
|
||||
"""Resolve Python type to Prisma type string."""
|
||||
# Special fields
|
||||
if name in PRISMA_SPECIAL:
|
||||
return PRISMA_SPECIAL[name]
|
||||
|
||||
base, is_optional = unwrap_optional(type_hint)
|
||||
optional = optional or is_optional
|
||||
origin = get_origin_name(base)
|
||||
type_name = get_type_name(base)
|
||||
|
||||
# Container types
|
||||
if origin == "dict" or origin == "list":
|
||||
result = PRISMA_TYPES.get(origin, "Json")
|
||||
return f"{result}?" if optional else result
|
||||
|
||||
# UUID / datetime
|
||||
if type_name in ("UUID", "datetime"):
|
||||
result = PRISMA_TYPES.get(type_name, "String")
|
||||
return f"{result}?" if optional else result
|
||||
|
||||
# Enum
|
||||
if isinstance(base, type) and issubclass(base, Enum):
|
||||
result = base.__name__
|
||||
return f"{result}?" if optional else result
|
||||
|
||||
# Basic types
|
||||
result = PRISMA_TYPES.get(base, "String")
|
||||
return f"{result}?" if optional else result
|
||||
Reference in New Issue
Block a user