diff --git a/soleprint/station/tools/modelgen/__main__.py b/soleprint/station/tools/modelgen/__main__.py index d834c71..aa3331d 100644 --- a/soleprint/station/tools/modelgen/__main__.py +++ b/soleprint/station/tools/modelgen/__main__.py @@ -20,6 +20,7 @@ Usage: python -m soleprint.station.tools.modelgen from-config -c config.json -o models.py python -m soleprint.station.tools.modelgen from-schema -o models/ --targets pydantic,typescript python -m soleprint.station.tools.modelgen extract --source /path/to/django --targets pydantic + python -m soleprint.station.tools.modelgen generate --config schema/modelgen.json """ import argparse @@ -72,10 +73,24 @@ def cmd_from_schema(args): print("that exports DATACLASSES and ENUMS lists.", file=sys.stderr) sys.exit(1) - print(f"Loading schema: {schema_path}") - schema = load_schema(schema_path) + # Parse include groups + include = None + if args.include: + include = {g.strip() for g in args.include.split(",")} - print(f"Found {len(schema.models)} models, {len(schema.enums)} enums") + print(f"Loading schema: {schema_path}") + schema = load_schema(schema_path, include=include) + + loaded = [] + if schema.models: + loaded.append(f"{len(schema.models)} models") + if schema.enums: + loaded.append(f"{len(schema.enums)} enums") + if schema.api_models: + loaded.append(f"{len(schema.api_models)} api models") + if schema.grpc_messages: + loaded.append(f"{len(schema.grpc_messages)} grpc messages") + print(f"Found {', '.join(loaded)}") # Parse targets targets = [t.strip() for t in args.targets.split(",")] @@ -163,6 +178,47 @@ def cmd_extract(args): print("Done!") +def cmd_generate(args): + """Generate all targets from a JSON config file.""" + import json + from .loader import load_schema + + config_path = Path(args.config) + if not config_path.exists(): + print(f"Error: Config file not found: {config_path}", file=sys.stderr) + sys.exit(1) + + with open(config_path) as f: + config = json.load(f) + + # Resolve paths relative to current working directory + schema_path = Path(config["schema"]) + if not schema_path.exists(): + print(f"Error: Schema folder not found: {schema_path}", file=sys.stderr) + sys.exit(1) + + print(f"Loading schema: {schema_path}") + + for target_conf in config["targets"]: + target = target_conf["target"] + output = Path(target_conf["output"]) + include = set(target_conf.get("include", [])) + name_map = target_conf.get("name_map", {}) + + if target not in GENERATORS: + print(f"Warning: Unknown target '{target}', skipping", file=sys.stderr) + continue + + # Load schema with this target's include filter + schema = load_schema(schema_path, include=include or None) + + generator = GENERATORS[target](name_map=name_map) + print(f"Generating {target} to: {output}") + generator.generate(schema, output) + + print("Done!") + + def cmd_list_formats(args): """List available output formats.""" print("Available output formats:") @@ -237,6 +293,12 @@ def main(): default="pydantic", help=f"Comma-separated output targets ({formats_str})", ) + schema_parser.add_argument( + "--include", + type=str, + default=None, + help="Comma-separated model groups to include (dataclasses,enums,api,grpc). Default: all.", + ) schema_parser.set_defaults(func=cmd_from_schema) # extract command @@ -275,6 +337,21 @@ def main(): ) extract_parser.set_defaults(func=cmd_extract) + + # generate command (config-driven multi-target) + gen_parser = subparsers.add_parser( + "generate", + help="Generate all targets from a JSON config file", + ) + gen_parser.add_argument( + "--config", + "-c", + type=str, + required=True, + help="Path to generation config file (e.g., schema/modelgen.json)", + ) + gen_parser.set_defaults(func=cmd_generate) + # list-formats command formats_parser = subparsers.add_parser( "list-formats", diff --git a/soleprint/station/tools/modelgen/generator/__init__.py b/soleprint/station/tools/modelgen/generator/__init__.py index 8e374dc..5abc0ac 100644 --- a/soleprint/station/tools/modelgen/generator/__init__.py +++ b/soleprint/station/tools/modelgen/generator/__init__.py @@ -7,12 +7,14 @@ Supported generators: - TypeScriptGenerator: TypeScript interfaces - ProtobufGenerator: Protocol Buffer definitions - PrismaGenerator: Prisma schema +- GrapheneGenerator: Graphene ObjectType/InputObjectType classes """ from typing import Dict, Type from .base import BaseGenerator from .django import DjangoGenerator +from .graphene import GrapheneGenerator from .prisma import PrismaGenerator from .protobuf import ProtobufGenerator from .pydantic import PydanticGenerator @@ -27,12 +29,14 @@ GENERATORS: Dict[str, Type[BaseGenerator]] = { "protobuf": ProtobufGenerator, "proto": ProtobufGenerator, # Alias "prisma": PrismaGenerator, + "graphene": GrapheneGenerator, } __all__ = [ "BaseGenerator", "PydanticGenerator", "DjangoGenerator", + "GrapheneGenerator", "TypeScriptGenerator", "ProtobufGenerator", "PrismaGenerator", diff --git a/soleprint/station/tools/modelgen/generator/base.py b/soleprint/station/tools/modelgen/generator/base.py index e3bd5ea..0c4ce9a 100644 --- a/soleprint/station/tools/modelgen/generator/base.py +++ b/soleprint/station/tools/modelgen/generator/base.py @@ -6,12 +6,19 @@ Abstract base class for all code generators. from abc import ABC, abstractmethod from pathlib import Path -from typing import Any +from typing import Any, Dict class BaseGenerator(ABC): """Abstract base for code generators.""" + def __init__(self, name_map: Dict[str, str] = None): + self.name_map = name_map or {} + + def map_name(self, name: str) -> str: + """Apply name_map to a model name.""" + return self.name_map.get(name, name) + @abstractmethod def generate(self, models: Any, output_path: Path) -> None: """Generate code for the given models to the specified path.""" diff --git a/soleprint/station/tools/modelgen/generator/django.py b/soleprint/station/tools/modelgen/generator/django.py index 418c2ff..e4368b1 100644 --- a/soleprint/station/tools/modelgen/generator/django.py +++ b/soleprint/station/tools/modelgen/generator/django.py @@ -217,12 +217,14 @@ class DjangoGenerator(BaseGenerator): # Enum if isinstance(base, type) and issubclass(base, Enum): + enum_name = base.__name__ extra = [] if optional: extra.append("null=True, blank=True") if default is not dc.MISSING and isinstance(default, Enum): - extra.append(f"default=Status.{default.name}") + extra.append(f"default={enum_name}.{default.name}") return DJANGO_TYPES["enum"].format( + enum_name=enum_name, opts=", " + ", ".join(extra) if extra else "" ) diff --git a/soleprint/station/tools/modelgen/generator/graphene.py b/soleprint/station/tools/modelgen/generator/graphene.py new file mode 100644 index 0000000..503bbba --- /dev/null +++ b/soleprint/station/tools/modelgen/generator/graphene.py @@ -0,0 +1,236 @@ +""" +Graphene Generator + +Generates graphene ObjectType and InputObjectType classes from model definitions. +Only generates type definitions — queries, mutations, and resolvers are hand-written. +""" + +from enum import Enum +from pathlib import Path +from typing import Any, List, get_type_hints + +from ..helpers import get_origin_name, get_type_name, unwrap_optional +from ..loader.schema import EnumDefinition, FieldDefinition, ModelDefinition +from ..types import GRAPHENE_RESOLVERS +from .base import BaseGenerator + + +class GrapheneGenerator(BaseGenerator): + """Generates graphene type definition files.""" + + def file_extension(self) -> str: + return ".py" + + def generate(self, models, output_path: Path) -> None: + """Generate graphene types to output_path.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + if hasattr(models, "models"): + # SchemaLoader + content = self._generate_from_definitions( + models.models, + getattr(models, "enums", []), + getattr(models, "api_models", []), + ) + elif isinstance(models, tuple): + content = self._generate_from_definitions(models[0], models[1], []) + elif isinstance(models, list): + content = self._generate_from_dataclasses(models) + else: + raise ValueError(f"Unsupported input type: {type(models)}") + + output_path.write_text(content) + + def _generate_from_definitions( + self, + models: List[ModelDefinition], + enums: List[EnumDefinition], + api_models: List[ModelDefinition], + ) -> str: + """Generate from ModelDefinition objects.""" + lines = self._generate_header() + + # Generate enums as graphene.Enum + for enum_def in enums: + lines.extend(self._generate_enum(enum_def)) + lines.append("") + lines.append("") + + # Generate domain models as ObjectType + for model_def in models: + lines.extend(self._generate_object_type(model_def)) + lines.append("") + lines.append("") + + # Generate API models — request types as InputObjectType, others as ObjectType + for model_def in api_models: + if model_def.name.endswith("Request"): + lines.extend(self._generate_input_type(model_def)) + else: + lines.extend(self._generate_object_type(model_def)) + lines.append("") + lines.append("") + + return "\n".join(lines).rstrip() + "\n" + + def _generate_from_dataclasses(self, dataclasses: List[type]) -> str: + """Generate from Python dataclasses.""" + lines = self._generate_header() + + enums_generated = set() + for cls in dataclasses: + hints = get_type_hints(cls) + for type_hint in hints.values(): + base, _ = unwrap_optional(type_hint) + if isinstance(base, type) and issubclass(base, Enum): + if base.__name__ not in enums_generated: + lines.extend(self._generate_enum_from_python(base)) + lines.append("") + lines.append("") + enums_generated.add(base.__name__) + + for cls in dataclasses: + lines.extend(self._generate_object_type_from_dataclass(cls)) + lines.append("") + lines.append("") + + return "\n".join(lines).rstrip() + "\n" + + def _generate_header(self) -> List[str]: + return [ + '"""', + "Graphene Types - GENERATED FILE", + "", + "Do not edit directly. Regenerate using modelgen.", + '"""', + "", + "import graphene", + "", + "", + ] + + def _generate_enum(self, enum_def: EnumDefinition) -> List[str]: + """Generate graphene.Enum from EnumDefinition.""" + lines = [f"class {enum_def.name}(graphene.Enum):"] + for name, value in enum_def.values: + lines.append(f' {name} = "{value}"') + return lines + + def _generate_enum_from_python(self, enum_cls: type) -> List[str]: + """Generate graphene.Enum from Python Enum.""" + lines = [f"class {enum_cls.__name__}(graphene.Enum):"] + for member in enum_cls: + lines.append(f' {member.name} = "{member.value}"') + return lines + + def _generate_object_type(self, model_def: ModelDefinition) -> List[str]: + """Generate graphene.ObjectType from ModelDefinition.""" + name = model_def.name + # Append Type suffix if not already present + type_name = f"{name}Type" if not name.endswith("Type") else name + + lines = [f"class {type_name}(graphene.ObjectType):"] + if model_def.docstring: + doc = model_def.docstring.strip().split("\n")[0] + lines.append(f' """{doc}"""') + lines.append("") + + if not model_def.fields: + lines.append(" pass") + else: + for field in model_def.fields: + graphene_type = self._resolve_type(field.type_hint, field.optional) + lines.append(f" {field.name} = {graphene_type}") + + return lines + + def _generate_input_type(self, model_def: ModelDefinition) -> List[str]: + """Generate graphene.InputObjectType from ModelDefinition.""" + import dataclasses as dc + + name = model_def.name + # Convert FooRequest -> FooInput + if name.endswith("Request"): + input_name = name[: -len("Request")] + "Input" + else: + input_name = f"{name}Input" + + lines = [f"class {input_name}(graphene.InputObjectType):"] + if model_def.docstring: + doc = model_def.docstring.strip().split("\n")[0] + lines.append(f' """{doc}"""') + lines.append("") + + if not model_def.fields: + lines.append(" pass") + else: + for field in model_def.fields: + graphene_type = self._resolve_type(field.type_hint, field.optional) + # Required only if not optional AND no default value + has_default = field.default is not dc.MISSING + if not field.optional and not has_default: + graphene_type = self._make_required(graphene_type) + elif has_default and not field.optional: + graphene_type = self._add_default(graphene_type, field.default) + lines.append(f" {field.name} = {graphene_type}") + + return lines + + def _generate_object_type_from_dataclass(self, cls: type) -> List[str]: + """Generate graphene.ObjectType from a dataclass.""" + import dataclasses as dc + + type_name = f"{cls.__name__}Type" + lines = [f"class {type_name}(graphene.ObjectType):"] + + hints = get_type_hints(cls) + for name, type_hint in hints.items(): + if name.startswith("_"): + continue + graphene_type = self._resolve_type(type_hint, False) + lines.append(f" {name} = {graphene_type}") + + return lines + + def _resolve_type(self, type_hint: Any, optional: bool) -> str: + """Resolve Python type to graphene field call string.""" + base, is_optional = unwrap_optional(type_hint) + optional = optional or is_optional + origin = get_origin_name(base) + type_name = get_type_name(base) + + # Look up resolver + resolver = ( + GRAPHENE_RESOLVERS.get(origin) + or GRAPHENE_RESOLVERS.get(type_name) + or GRAPHENE_RESOLVERS.get(base) + or ( + GRAPHENE_RESOLVERS["enum"] + if isinstance(base, type) and issubclass(base, Enum) + else None + ) + ) + + result = resolver(base) if resolver else "graphene.String" + + # List types already have () syntax from resolver + if result.startswith("graphene.List("): + return result + + # Scalar types: add () call + return f"{result}()" + + def _make_required(self, field_str: str) -> str: + """Add required=True to a graphene field.""" + if field_str.endswith("()"): + return field_str[:-1] + "required=True)" + return field_str + + def _add_default(self, field_str: str, default: Any) -> str: + """Add default_value to a graphene field.""" + if callable(default): + # default_factory — skip, graphene doesn't support factories + return field_str + if field_str.endswith("()"): + return field_str[:-1] + f"default_value={default!r})" + return field_str diff --git a/soleprint/station/tools/modelgen/generator/pydantic.py b/soleprint/station/tools/modelgen/generator/pydantic.py index f2410b4..c2676f2 100644 --- a/soleprint/station/tools/modelgen/generator/pydantic.py +++ b/soleprint/station/tools/modelgen/generator/pydantic.py @@ -2,8 +2,12 @@ Pydantic Generator Generates Pydantic BaseModel classes from model definitions. +Supports two output modes: +- File output: flat models (backwards compatible) +- Directory output: CRUD variants (Create/Update/Response) per model """ +import dataclasses as dc from enum import Enum from pathlib import Path from typing import Any, List, get_type_hints @@ -13,6 +17,13 @@ from ..loader.schema import EnumDefinition, FieldDefinition, ModelDefinition from ..types import PYDANTIC_RESOLVERS from .base import BaseGenerator +# Fields to skip per CRUD variant +SKIP_FIELDS = { + "Create": {"id", "created_at", "updated_at", "status", "error_message"}, + "Update": {"id", "created_at", "updated_at"}, + "Response": set(), +} + class PydanticGenerator(BaseGenerator): """Generates Pydantic model files.""" @@ -21,52 +32,187 @@ class PydanticGenerator(BaseGenerator): return ".py" def generate(self, models, output_path: Path) -> None: - """Generate Pydantic models to output_path.""" + """Generate Pydantic models to output_path. + + If output_path is a directory (or doesn't end in .py), generate + multi-file CRUD variants. Otherwise, generate flat models to a + single file. + """ + output_path = Path(output_path) + + if output_path.suffix != ".py": + # Directory mode: CRUD variants + self._generate_crud_directory(models, output_path) + else: + # File mode: flat models (backwards compatible) + self._generate_flat_file(models, output_path) + + def _generate_flat_file(self, models, output_path: Path) -> None: + """Generate flat models to a single file (original behavior).""" output_path.parent.mkdir(parents=True, exist_ok=True) - # Detect input type and generate accordingly if hasattr(models, "get_shared_component"): - # ConfigLoader (soleprint config) content = self._generate_from_config(models) elif hasattr(models, "models"): - # SchemaLoader content = self._generate_from_definitions( models.models, getattr(models, "enums", []) ) elif isinstance(models, tuple): - # (models, enums) tuple from extractor content = self._generate_from_definitions(models[0], models[1]) elif isinstance(models, list): - # List of dataclasses (MPR style) content = self._generate_from_dataclasses(models) else: raise ValueError(f"Unsupported input type: {type(models)}") output_path.write_text(content) + def _generate_crud_directory(self, models, output_dir: Path) -> None: + """Generate CRUD variant files in a directory.""" + output_dir.mkdir(parents=True, exist_ok=True) + + if hasattr(models, "models"): + model_defs = models.models + enum_defs = getattr(models, "enums", []) + elif isinstance(models, tuple): + model_defs = models[0] + enum_defs = models[1] + else: + raise ValueError(f"Unsupported input type for CRUD mode: {type(models)}") + + # base.py + base_content = "\n".join([ + '"""Pydantic Base Schema - GENERATED FILE"""', + "", + "from pydantic import BaseModel, ConfigDict", + "", + "", + "class BaseSchema(BaseModel):", + ' """Base schema with ORM mode."""', + " model_config = ConfigDict(from_attributes=True)", + "", + ]) + (output_dir / "base.py").write_text(base_content) + + # Per-model files + imports = ["from .base import BaseSchema"] + all_exports = ['"BaseSchema"'] + + for model_def in model_defs: + mapped = self.map_name(model_def.name) + module_name = mapped.lower() + + lines = [ + f'"""{model_def.name} Schemas - GENERATED FILE"""', + "", + "from datetime import datetime", + "from enum import Enum", + "from typing import Any, Dict, List, Optional", + "from uuid import UUID", + "", + "from .base import BaseSchema", + "", + ] + + # Inline enums used by this model + model_enums = self._collect_model_enums(model_def, enum_defs) + for enum_def in model_enums: + lines.append("") + lines.extend(self._generate_enum(enum_def)) + lines.append("") + + # CRUD variants + for suffix in ["Create", "Update", "Response"]: + lines.append("") + lines.extend(self._generate_crud_model(model_def, mapped, suffix)) + + lines.append("") + content = "\n".join(lines) + (output_dir / f"{module_name}.py").write_text(content) + + # Track imports + imports.append( + f"from .{module_name} import {mapped}Create, {mapped}Update, {mapped}Response" + ) + all_exports.extend([ + f'"{mapped}Create"', f'"{mapped}Update"', f'"{mapped}Response"' + ]) + + for enum_def in model_enums: + imports.append(f"from .{module_name} import {enum_def.name}") + all_exports.append(f'"{enum_def.name}"') + + # __init__.py + init_content = "\n".join([ + '"""API Schemas - GENERATED FILE"""', + "", + *imports, + "", + f"__all__ = [{', '.join(all_exports)}]", + "", + ]) + (output_dir / "__init__.py").write_text(init_content) + + def _collect_model_enums( + self, model_def: ModelDefinition, enum_defs: List[EnumDefinition] + ) -> List[EnumDefinition]: + """Find enums referenced by a model's fields.""" + enum_names = set() + for field in model_def.fields: + base, _ = unwrap_optional(field.type_hint) + if isinstance(base, type) and issubclass(base, Enum): + enum_names.add(base.__name__) + return [e for e in enum_defs if e.name in enum_names] + + def _generate_crud_model( + self, model_def: ModelDefinition, mapped_name: str, suffix: str + ) -> List[str]: + """Generate a single CRUD variant (Create/Update/Response).""" + class_name = f"{mapped_name}{suffix}" + skip = SKIP_FIELDS.get(suffix, set()) + + lines = [ + f"class {class_name}(BaseSchema):", + f' """{class_name} schema."""', + ] + + has_fields = False + for field in model_def.fields: + if field.name.startswith("_") or field.name in skip: + continue + + has_fields = True + py_type = self._resolve_type(field.type_hint, field.optional) + + # Update variant: all fields optional + if suffix == "Update" and "Optional" not in py_type: + py_type = f"Optional[{py_type}]" + + default = self._format_default(field.default, "Optional" in py_type) + lines.append(f" {field.name}: {py_type}{default}") + + if not has_fields: + lines.append(" pass") + + return lines + + # ========================================================================= + # Flat file generation (original behavior) + # ========================================================================= + def _generate_from_definitions( self, models: List[ModelDefinition], enums: List[EnumDefinition] ) -> str: - """Generate from ModelDefinition objects (schema/extract mode).""" lines = self._generate_header() - - # Generate enums for enum_def in enums: lines.extend(self._generate_enum(enum_def)) lines.append("") - - # Generate models for model_def in models: lines.extend(self._generate_model_from_definition(model_def)) lines.append("") - return "\n".join(lines) def _generate_from_dataclasses(self, dataclasses: List[type]) -> str: - """Generate from Python dataclasses (MPR style).""" lines = self._generate_header() - - # Collect and generate enums first enums_generated = set() for cls in dataclasses: hints = get_type_hints(cls) @@ -77,16 +223,12 @@ class PydanticGenerator(BaseGenerator): lines.extend(self._generate_enum_from_python(base)) lines.append("") enums_generated.add(base.__name__) - - # Generate models for cls in dataclasses: lines.extend(self._generate_model_from_dataclass(cls)) lines.append("") - return "\n".join(lines) def _generate_header(self) -> List[str]: - """Generate file header.""" return [ '"""', "Pydantic Models - GENERATED FILE", @@ -104,27 +246,23 @@ class PydanticGenerator(BaseGenerator): ] def _generate_enum(self, enum_def: EnumDefinition) -> List[str]: - """Generate Pydantic enum from EnumDefinition.""" lines = [f"class {enum_def.name}(str, Enum):"] for name, value in enum_def.values: lines.append(f' {name} = "{value}"') return lines def _generate_enum_from_python(self, enum_cls: type) -> List[str]: - """Generate Pydantic enum from Python Enum.""" lines = [f"class {enum_cls.__name__}(str, Enum):"] for member in enum_cls: lines.append(f' {member.name} = "{member.value}"') return lines def _generate_model_from_definition(self, model_def: ModelDefinition) -> List[str]: - """Generate Pydantic model from ModelDefinition.""" docstring = model_def.docstring or model_def.name lines = [ f"class {model_def.name}(BaseModel):", f' """{docstring.strip().split(chr(10))[0]}"""', ] - if not model_def.fields: lines.append(" pass") else: @@ -132,46 +270,34 @@ class PydanticGenerator(BaseGenerator): py_type = self._resolve_type(field.type_hint, field.optional) default = self._format_default(field.default, field.optional) lines.append(f" {field.name}: {py_type}{default}") - return lines def _generate_model_from_dataclass(self, cls: type) -> List[str]: - """Generate Pydantic model from a dataclass.""" - import dataclasses as dc - docstring = cls.__doc__ or cls.__name__ lines = [ f"class {cls.__name__}(BaseModel):", f' """{docstring.strip().split(chr(10))[0]}"""', ] - hints = get_type_hints(cls) fields = {f.name: f for f in dc.fields(cls)} - for name, type_hint in hints.items(): if name.startswith("_"): continue - field = fields.get(name) default_val = dc.MISSING if field: if field.default is not dc.MISSING: default_val = field.default - py_type = self._resolve_type(type_hint, False) default = self._format_default(default_val, "Optional" in py_type) lines.append(f" {name}: {py_type}{default}") - return lines def _resolve_type(self, type_hint: Any, optional: bool) -> str: - """Resolve Python type to Pydantic type string.""" base, is_optional = unwrap_optional(type_hint) optional = optional or is_optional origin = get_origin_name(base) type_name = get_type_name(base) - - # Look up resolver resolver = ( PYDANTIC_RESOLVERS.get(origin) or PYDANTIC_RESOLVERS.get(type_name) @@ -182,14 +308,10 @@ class PydanticGenerator(BaseGenerator): else None ) ) - result = resolver(base) if resolver else "str" return f"Optional[{result}]" if optional else result def _format_default(self, default: Any, optional: bool) -> str: - """Format default value for field.""" - import dataclasses as dc - if optional: return " = None" if default is dc.MISSING or default is None: @@ -204,7 +326,6 @@ class PydanticGenerator(BaseGenerator): def _generate_from_config(self, config) -> str: """Generate from ConfigLoader (soleprint config.json mode).""" - # Get component names from config config_comp = config.get_shared_component("config") data_comp = config.get_shared_component("data") diff --git a/soleprint/station/tools/modelgen/generator/typescript.py b/soleprint/station/tools/modelgen/generator/typescript.py index fcca18f..e1cc5f8 100644 --- a/soleprint/station/tools/modelgen/generator/typescript.py +++ b/soleprint/station/tools/modelgen/generator/typescript.py @@ -26,9 +26,10 @@ class TypeScriptGenerator(BaseGenerator): # Handle different input types if hasattr(models, "models"): - # SchemaLoader + # SchemaLoader — include api_models if present + all_models = models.models + getattr(models, "api_models", []) content = self._generate_from_definitions( - models.models, getattr(models, "enums", []) + all_models, getattr(models, "enums", []) ) elif isinstance(models, tuple): # (models, enums) tuple diff --git a/soleprint/station/tools/modelgen/loader/schema.py b/soleprint/station/tools/modelgen/loader/schema.py index 360ca2b..78833f3 100644 --- a/soleprint/station/tools/modelgen/loader/schema.py +++ b/soleprint/station/tools/modelgen/loader/schema.py @@ -5,6 +5,7 @@ Loads Python dataclasses from a schema/ folder. Expects the folder to have an __init__.py that exports: - DATACLASSES: List of dataclass types to generate - ENUMS: List of Enum types to include +- API_MODELS: (optional) List of API request/response types - GRPC_MESSAGES: (optional) List of gRPC message types - GRPC_SERVICE: (optional) gRPC service definition dict """ @@ -60,12 +61,18 @@ class SchemaLoader: def __init__(self, schema_path: Path): self.schema_path = Path(schema_path) self.models: List[ModelDefinition] = [] + self.api_models: List[ModelDefinition] = [] self.enums: List[EnumDefinition] = [] self.grpc_messages: List[ModelDefinition] = [] self.grpc_service: Optional[GrpcServiceDefinition] = None - def load(self) -> "SchemaLoader": - """Load schema definitions from the schema folder.""" + def load(self, include: Optional[set] = None) -> "SchemaLoader": + """Load schema definitions from the schema folder. + + Args: + include: Set of groups to load (dataclasses, enums, api, grpc). + None means load all groups. + """ init_path = self.schema_path / "__init__.py" if not init_path.exists(): @@ -74,29 +81,41 @@ class SchemaLoader: # Import the schema module module = self._import_module(init_path) + load_all = include is None + # Extract DATACLASSES - dataclasses = getattr(module, "DATACLASSES", []) - for cls in dataclasses: - self.models.append(self._parse_dataclass(cls)) + if load_all or "dataclasses" in include: + dataclasses = getattr(module, "DATACLASSES", []) + for cls in dataclasses: + self.models.append(self._parse_dataclass(cls)) + + # Extract API_MODELS (request/response types) + if load_all or "api" in include: + api_models = getattr(module, "API_MODELS", []) + for cls in api_models: + self.api_models.append(self._parse_dataclass(cls)) # Extract ENUMS - enums = getattr(module, "ENUMS", []) - for enum_cls in enums: - self.enums.append(self._parse_enum(enum_cls)) + if load_all or "enums" in include: + enums = getattr(module, "ENUMS", []) + for enum_cls in enums: + self.enums.append(self._parse_enum(enum_cls)) # Extract GRPC_MESSAGES (optional) - grpc_messages = getattr(module, "GRPC_MESSAGES", []) - for cls in grpc_messages: - self.grpc_messages.append(self._parse_dataclass(cls)) + if load_all or "grpc" in include: + grpc_messages = getattr(module, "GRPC_MESSAGES", []) + for cls in grpc_messages: + self.grpc_messages.append(self._parse_dataclass(cls)) # Extract GRPC_SERVICE (optional) - grpc_service = getattr(module, "GRPC_SERVICE", None) - if grpc_service: - self.grpc_service = GrpcServiceDefinition( - package=grpc_service.get("package", "service"), - name=grpc_service.get("name", "Service"), - methods=grpc_service.get("methods", []), - ) + if load_all or "grpc" in include: + grpc_service = getattr(module, "GRPC_SERVICE", None) + if grpc_service: + self.grpc_service = GrpcServiceDefinition( + package=grpc_service.get("package", "service"), + name=grpc_service.get("name", "Service"), + methods=grpc_service.get("methods", []), + ) return self @@ -163,7 +182,7 @@ class SchemaLoader: return False -def load_schema(schema_path: str | Path) -> SchemaLoader: +def load_schema(schema_path: str | Path, include: Optional[set] = None) -> SchemaLoader: """Load schema definitions from folder.""" loader = SchemaLoader(schema_path) - return loader.load() + return loader.load(include=include) diff --git a/soleprint/station/tools/modelgen/types.py b/soleprint/station/tools/modelgen/types.py index b029437..cf35e48 100644 --- a/soleprint/station/tools/modelgen/types.py +++ b/soleprint/station/tools/modelgen/types.py @@ -22,7 +22,7 @@ DJANGO_TYPES: dict[Any, str] = { "list": "models.JSONField(default=list, blank=True)", "text": "models.TextField(blank=True, default='')", "bigint": "models.BigIntegerField({opts})", - "enum": "models.CharField(max_length=20, choices=Status.choices{opts})", + "enum": "models.CharField(max_length=20, choices={enum_name}.choices{opts})", } DJANGO_SPECIAL: dict[str, str] = { @@ -137,3 +137,36 @@ PRISMA_SPECIAL: dict[str, str] = { "created_at": "DateTime @default(now())", "updated_at": "DateTime @updatedAt", } + +# ============================================================================= +# Graphene Type Resolvers +# ============================================================================= + + +def _resolve_graphene_list(base: Any) -> str: + """Resolve graphene List type.""" + args = get_args(base) + if args: + inner = args[0] + if inner is str: + return "graphene.List(graphene.String)" + elif inner is int: + return "graphene.List(graphene.Int)" + elif inner is float: + return "graphene.List(graphene.Float)" + elif inner is bool: + return "graphene.List(graphene.Boolean)" + return "graphene.List(graphene.String)" + + +GRAPHENE_RESOLVERS: dict[Any, Callable[[Any], str]] = { + str: lambda _: "graphene.String", + int: lambda _: "graphene.Int", + float: lambda _: "graphene.Float", + bool: lambda _: "graphene.Boolean", + "UUID": lambda _: "graphene.UUID", + "datetime": lambda _: "graphene.DateTime", + "dict": lambda _: "graphene.JSONString", + "list": _resolve_graphene_list, + "enum": lambda base: f"graphene.String", # Enums exposed as strings in GQL +}