"""
Strawberry Generator

Generates strawberry type, input, and enum classes from model definitions.

Only generates type definitions — queries, mutations, and resolvers are
hand-written.
"""

import dataclasses as dc
import json
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, List, Optional, get_origin, get_type_hints

from ..helpers import get_origin_name, get_type_name, unwrap_optional
from ..loader.schema import EnumDefinition, FieldDefinition, ModelDefinition
from ..types import STRAWBERRY_RESOLVERS
from .base import BaseGenerator

# Indentation used for class bodies in the generated module (PEP 8).
_INDENT = "    "


def _is_enum_class(obj: Any) -> bool:
    """Return True if ``obj`` is an ``Enum`` subclass.

    On some Python versions ``isinstance(list[int], type)`` is True while
    ``issubclass(list[int], Enum)`` raises ``TypeError``; such objects are
    treated as "not an enum" instead of crashing generation.
    """
    try:
        return isinstance(obj, type) and issubclass(obj, Enum)
    except TypeError:
        return False


def _py_string_literal(value: Any) -> str:
    """Return a correctly escaped, double-quoted Python literal for ``f"{value}"``.

    JSON string escapes are a subset of Python's string escapes with the same
    meaning, so ``json.dumps`` yields a valid literal; plain values come out
    exactly as ``"value"``. ``ensure_ascii=False`` keeps non-ASCII text
    verbatim (ASCII escaping would turn non-BMP characters into surrogate-pair
    escapes, which Python string literals do not recombine).
    """
    return json.dumps(f"{value}", ensure_ascii=False)


def _docstring_lines(docstring: Optional[str]) -> List[str]:
    """Return a one-line class docstring plus a blank line, or ``[]``.

    Only the first line of the stripped docstring is kept. Backslashes and
    double quotes are escaped so the summary cannot terminate or corrupt the
    surrounding triple-quoted string in the generated module.
    """
    if not docstring:
        return []
    summary = docstring.strip().split("\n")[0].strip()
    if not summary:
        return []
    escaped = summary.replace("\\", "\\\\").replace('"', '\\"')
    return [f'{_INDENT}"""{escaped}"""', ""]


def _default_literal(value: Any) -> str:
    """Return Python source text for a field default value.

    Enum members are rendered as ``EnumName.MEMBER`` because their ``repr``
    (``<Status.ACTIVE: 'active'>``) is not valid Python. This refers to the
    enum class emitted into the generated module; it assumes that class has
    the same name as the Python enum (true for enums emitted from Python
    classes; for ``EnumDefinition`` input, presumably ``enum_def.name``
    matches — verify against the loader). Everything else uses ``repr``.
    """
    if isinstance(value, Enum):
        return f"{type(value).__name__}.{value.name}"
    return repr(value)


class StrawberryGenerator(BaseGenerator):
    """Generates strawberry type definition files."""

    def file_extension(self) -> str:
        return ".py"

    def generate(self, models, output_path: Path) -> None:
        """Generate strawberry types to ``output_path``.

        Args:
            models: One of
                - an object with a ``models`` attribute (e.g. a SchemaLoader),
                  optionally also exposing ``enums`` and ``api_models``;
                - a tuple ``(models, enums)`` or ``(models, enums, api_models)``;
                - a list of dataclasses.
            output_path: File to write; parent directories are created.

        Raises:
            ValueError: If ``models`` is none of the supported inputs.
        """
        if hasattr(models, "models"):  # SchemaLoader
            content = self._generate_from_definitions(
                models.models,
                getattr(models, "enums", None) or [],
                getattr(models, "api_models", None) or [],
            )
        elif isinstance(models, tuple):
            # Accept an optional third element instead of silently dropping it.
            api_models = models[2] if len(models) > 2 else []
            content = self._generate_from_definitions(models[0], models[1], api_models)
        elif isinstance(models, list):
            content = self._generate_from_dataclasses(models)
        else:
            raise ValueError(f"Unsupported input type: {type(models)}")

        # Touch the filesystem only once generation has succeeded.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Python source is UTF-8 by default; don't depend on the locale encoding.
        output_path.write_text(content, encoding="utf-8")

    @staticmethod
    def _append_block(lines: List[str], block: List[str]) -> None:
        """Append a generated definition followed by two blank lines (PEP 8)."""
        lines.extend(block)
        lines.extend(["", ""])

    def _generate_from_definitions(
        self,
        models: List[ModelDefinition],
        enums: List[EnumDefinition],
        api_models: List[ModelDefinition],
    ) -> str:
        """Render enums, model object types, then API models.

        API models whose name ends in ``Request`` become ``@strawberry.input``
        classes; all other API models become ``@strawberry.type`` classes.
        """
        lines = self._generate_header()

        for enum_def in enums:
            self._append_block(lines, self._generate_enum(enum_def))

        for model_def in models:
            self._append_block(lines, self._generate_object_type(model_def))

        for model_def in api_models:
            if model_def.name.endswith("Request"):
                self._append_block(lines, self._generate_input_type(model_def))
            else:
                self._append_block(lines, self._generate_object_type(model_def))

        return "\n".join(lines).rstrip() + "\n"

    def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
        """Render enums referenced by the dataclasses, then one type per class.

        Only enums used directly (or as ``Optional[...]``) as a field type are
        discovered; each is emitted once, keyed by class name.
        """
        lines = self._generate_header()

        seen_enums = set()
        for cls in dataclasses:
            for type_hint in get_type_hints(cls).values():
                base, _ = unwrap_optional(type_hint)
                if _is_enum_class(base) and base.__name__ not in seen_enums:
                    seen_enums.add(base.__name__)
                    self._append_block(lines, self._generate_enum_from_python(base))

        for cls in dataclasses:
            self._append_block(lines, self._generate_object_type_from_dataclass(cls))

        return "\n".join(lines).rstrip() + "\n"

    def _generate_header(self) -> List[str]:
        """Return the generated module's docstring and import lines."""
        return [
            '"""',
            "Strawberry Types - GENERATED FILE",
            "",
            "Do not edit directly. Regenerate using modelgen.",
            '"""',
            "",
            "import strawberry",
            "from enum import Enum",
            "from typing import List, Optional",
            "from uuid import UUID",
            "from datetime import datetime",
            "from strawberry.scalars import JSON",
            "",
            "",
        ]

    def _generate_enum(self, enum_def: EnumDefinition) -> List[str]:
        """Render a ``@strawberry.enum`` from an ``EnumDefinition``'s (name, value) pairs."""
        lines = ["@strawberry.enum", f"class {enum_def.name}(Enum):"]
        for name, value in enum_def.values:
            lines.append(f"{_INDENT}{name} = {_py_string_literal(value)}")
        if len(lines) == 2:
            # A class statement needs a body even if there are no members.
            lines.append(f"{_INDENT}pass")
        return lines

    def _generate_enum_from_python(self, enum_cls: type) -> List[str]:
        """Render a ``@strawberry.enum`` mirroring a Python ``Enum`` class.

        Member values are emitted as string literals of their formatted value.
        """
        lines = ["@strawberry.enum", f"class {enum_cls.__name__}(Enum):"]
        for member in enum_cls:
            lines.append(f"{_INDENT}{member.name} = {_py_string_literal(member.value)}")
        if len(lines) == 2:
            lines.append(f"{_INDENT}pass")
        return lines

    def _generate_object_type(self, model_def: ModelDefinition) -> List[str]:
        """Render a ``@strawberry.type`` whose fields are all Optional and default to None.

        The class name gets a ``Type`` suffix unless it already has one.
        """
        name = model_def.name
        type_name = name if name.endswith("Type") else f"{name}Type"
        lines = ["@strawberry.type", f"class {type_name}:"]
        lines.extend(_docstring_lines(model_def.docstring))

        fields = model_def.fields or []
        if not fields:
            lines.append(f"{_INDENT}pass")
        for field in fields:
            type_str = self._resolve_type(field.type_hint, optional=True)
            lines.append(f"{_INDENT}{field.name}: {type_str} = None")
        return lines

    def _generate_input_type(self, model_def: ModelDefinition) -> List[str]:
        """Render a ``@strawberry.input`` class from a model definition.

        ``FooRequest`` becomes ``FooInput``; any other name gets an ``Input``
        suffix. Required fields (not optional, no default) come first,
        because a dataclass-style class rejects a field without a default
        after one with a default. Non-callable defaults are emitted
        literally. Optional fields and callable (factory) defaults are
        emitted as ``Optional[...] = None``.
        """
        name = model_def.name
        if name.endswith("Request"):
            input_name = name[: -len("Request")] + "Input"
        else:
            input_name = f"{name}Input"

        lines = ["@strawberry.input", f"class {input_name}:"]
        lines.extend(_docstring_lines(model_def.docstring))

        fields = model_def.fields or []
        if not fields:
            lines.append(f"{_INDENT}pass")
            return lines

        required = [f for f in fields if not f.optional and f.default is dc.MISSING]
        defaulted = [f for f in fields if f.optional or f.default is not dc.MISSING]

        for field in required:
            type_str = self._resolve_type(field.type_hint, optional=False)
            lines.append(f"{_INDENT}{field.name}: {type_str}")

        for field in defaulted:
            default = field.default
            if default is not dc.MISSING and not callable(default):
                type_str = self._resolve_type(field.type_hint, optional=False)
                lines.append(
                    f"{_INDENT}{field.name}: {type_str} = {_default_literal(default)}"
                )
            else:
                type_str = self._resolve_type(field.type_hint, optional=True)
                lines.append(f"{_INDENT}{field.name}: {type_str} = None")
        return lines

    def _generate_object_type_from_dataclass(self, cls: type) -> List[str]:
        """Render ``<ClassName>Type`` with every public annotated attribute as Optional = None.

        Underscore-prefixed names and ``ClassVar`` annotations (class-level
        attributes, not instance fields) are skipped. If nothing remains,
        ``pass`` is emitted so the generated class is still valid Python.
        """
        lines = ["@strawberry.type", f"class {cls.__name__}Type:"]
        body = []
        for name, type_hint in get_type_hints(cls).items():
            if name.startswith("_"):
                continue
            if type_hint is ClassVar or get_origin(type_hint) is ClassVar:
                continue
            type_str = self._resolve_type(type_hint, optional=True)
            body.append(f"{_INDENT}{name}: {type_str} = None")
        lines.extend(body or [f"{_INDENT}pass"])
        return lines

    def _resolve_type(self, type_hint: Any, optional: bool) -> str:
        """Resolve a Python type hint to a strawberry annotation string.

        Resolvers are looked up in ``STRAWBERRY_RESOLVERS`` in this order:
        by origin name, then type name, then the type object itself, then the
        ``"enum"`` entry for Enum subclasses. The fallback is ``"str"``. The
        result is wrapped in ``Optional[...]`` if ``optional`` is set or the
        hint itself is optional.
        """
        base, is_optional = unwrap_optional(type_hint)
        origin = get_origin_name(base)
        type_name = get_type_name(base)

        resolver = (
            STRAWBERRY_RESOLVERS.get(origin)
            or STRAWBERRY_RESOLVERS.get(type_name)
            or STRAWBERRY_RESOLVERS.get(base)
            or (STRAWBERRY_RESOLVERS["enum"] if _is_enum_class(base) else None)
        )
        inner = resolver(base) if resolver else "str"
        return f"Optional[{inner}]" if (optional or is_optional) else inner