executor abstraction, graphene to strawberry

This commit is contained in:
2026-03-12 23:27:34 -03:00
parent 4e9d731cff
commit eaaf2ad60c
13 changed files with 796 additions and 276 deletions

View File

@@ -0,0 +1,220 @@
"""
Strawberry Generator
Generates strawberry type, input, and enum classes from model definitions.
Only generates type definitions — queries, mutations, and resolvers are hand-written.
"""
import dataclasses as dc
from enum import Enum
from pathlib import Path
from typing import Any, List, get_type_hints
from ..helpers import get_origin_name, get_type_name, unwrap_optional
from ..loader.schema import EnumDefinition, FieldDefinition, ModelDefinition
from ..types import STRAWBERRY_RESOLVERS
from .base import BaseGenerator
class StrawberryGenerator(BaseGenerator):
"""Generates strawberry type definition files."""
def file_extension(self) -> str:
return ".py"
def generate(self, models, output_path: Path) -> None:
"""Generate strawberry types to output_path."""
output_path.parent.mkdir(parents=True, exist_ok=True)
if hasattr(models, "models"):
# SchemaLoader
content = self._generate_from_definitions(
models.models,
getattr(models, "enums", []),
getattr(models, "api_models", []),
)
elif isinstance(models, tuple):
content = self._generate_from_definitions(models[0], models[1], [])
elif isinstance(models, list):
content = self._generate_from_dataclasses(models)
else:
raise ValueError(f"Unsupported input type: {type(models)}")
output_path.write_text(content)
def _generate_from_definitions(
self,
models: List[ModelDefinition],
enums: List[EnumDefinition],
api_models: List[ModelDefinition],
) -> str:
lines = self._generate_header()
for enum_def in enums:
lines.extend(self._generate_enum(enum_def))
lines.append("")
lines.append("")
for model_def in models:
lines.extend(self._generate_object_type(model_def))
lines.append("")
lines.append("")
for model_def in api_models:
if model_def.name.endswith("Request"):
lines.extend(self._generate_input_type(model_def))
else:
lines.extend(self._generate_object_type(model_def))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
lines = self._generate_header()
enums_generated = set()
for cls in dataclasses:
hints = get_type_hints(cls)
for type_hint in hints.values():
base, _ = unwrap_optional(type_hint)
if isinstance(base, type) and issubclass(base, Enum):
if base.__name__ not in enums_generated:
lines.extend(self._generate_enum_from_python(base))
lines.append("")
lines.append("")
enums_generated.add(base.__name__)
for cls in dataclasses:
lines.extend(self._generate_object_type_from_dataclass(cls))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def _generate_header(self) -> List[str]:
return [
'"""',
"Strawberry Types - GENERATED FILE",
"",
"Do not edit directly. Regenerate using modelgen.",
'"""',
"",
"import strawberry",
"from enum import Enum",
"from typing import List, Optional",
"from uuid import UUID",
"from datetime import datetime",
"from strawberry.scalars import JSON",
"",
"",
]
def _generate_enum(self, enum_def: EnumDefinition) -> List[str]:
lines = ["@strawberry.enum", f"class {enum_def.name}(Enum):"]
for name, value in enum_def.values:
lines.append(f' {name} = "{value}"')
return lines
def _generate_enum_from_python(self, enum_cls: type) -> List[str]:
lines = ["@strawberry.enum", f"class {enum_cls.__name__}(Enum):"]
for member in enum_cls:
lines.append(f' {member.name} = "{member.value}"')
return lines
def _generate_object_type(self, model_def: ModelDefinition) -> List[str]:
name = model_def.name
type_name = f"{name}Type" if not name.endswith("Type") else name
lines = ["@strawberry.type", f"class {type_name}:"]
if model_def.docstring:
doc = model_def.docstring.strip().split("\n")[0]
lines.append(f' """{doc}"""')
lines.append("")
if not model_def.fields:
lines.append(" pass")
else:
for field in model_def.fields:
type_str = self._resolve_type(field.type_hint, optional=True)
lines.append(f" {field.name}: {type_str} = None")
return lines
def _generate_input_type(self, model_def: ModelDefinition) -> List[str]:
name = model_def.name
if name.endswith("Request"):
input_name = name[: -len("Request")] + "Input"
else:
input_name = f"{name}Input"
lines = ["@strawberry.input", f"class {input_name}:"]
if model_def.docstring:
doc = model_def.docstring.strip().split("\n")[0]
lines.append(f' """{doc}"""')
lines.append("")
if not model_def.fields:
lines.append(" pass")
else:
# Required fields first, then optional/defaulted
required = []
optional = []
for field in model_def.fields:
has_default = field.default is not dc.MISSING
if not field.optional and not has_default:
required.append(field)
else:
optional.append(field)
for field in required:
type_str = self._resolve_type(field.type_hint, optional=False)
lines.append(f" {field.name}: {type_str}")
for field in optional:
has_default = field.default is not dc.MISSING
if has_default and not callable(field.default):
type_str = self._resolve_type(field.type_hint, optional=False)
lines.append(f" {field.name}: {type_str} = {field.default!r}")
else:
type_str = self._resolve_type(field.type_hint, optional=True)
lines.append(f" {field.name}: {type_str} = None")
return lines
def _generate_object_type_from_dataclass(self, cls: type) -> List[str]:
type_name = f"{cls.__name__}Type"
lines = ["@strawberry.type", f"class {type_name}:"]
hints = get_type_hints(cls)
for name, type_hint in hints.items():
if name.startswith("_"):
continue
type_str = self._resolve_type(type_hint, optional=True)
lines.append(f" {name}: {type_str} = None")
return lines
def _resolve_type(self, type_hint: Any, optional: bool) -> str:
"""Resolve Python type hint to a strawberry annotation string."""
base, is_optional = unwrap_optional(type_hint)
optional = optional or is_optional
origin = get_origin_name(base)
type_name = get_type_name(base)
resolver = (
STRAWBERRY_RESOLVERS.get(origin)
or STRAWBERRY_RESOLVERS.get(type_name)
or STRAWBERRY_RESOLVERS.get(base)
or (
STRAWBERRY_RESOLVERS["enum"]
if isinstance(base, type) and issubclass(base, Enum)
else None
)
)
inner = resolver(base) if resolver else "str"
if optional:
return f"Optional[{inner}]"
return inner