176 lines
5.3 KiB
Python
176 lines
5.3 KiB
Python
"""
|
|
Schema Loader
|
|
|
|
Loads Python dataclasses from a schema/ folder.
|
|
Expects the folder to have an __init__.py that exports:
|
|
- DATACLASSES: List of dataclass types to generate
|
|
- ENUMS: List of Enum types to include
|
|
- GRPC_MESSAGES: (optional) List of gRPC message types
|
|
- GRPC_SERVICE: (optional) gRPC service definition dict
|
|
"""
|
|
|
|
import dataclasses as dc
|
|
import importlib.util
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Type, get_type_hints
|
|
|
|
|
|
@dataclass
class FieldDefinition:
    """Represents a single field of a parsed model.

    Attributes:
        name: Field name as declared on the source dataclass.
        type_hint: The field's type annotation.
        default: The field's default value, the ``default_factory`` callable
            itself when the source field uses one, or ``dataclasses.MISSING``
            when the field has no default.
        optional: True when the type hint is a union that includes ``None``.
    """

    name: str
    type_hint: Any
    # ``default: Any = dc.MISSING`` would NOT create a default: the dataclass
    # machinery uses MISSING as its own "no default" sentinel, which made
    # ``default`` a required constructor argument. A factory returning the
    # sentinel gives the intended "defaults to MISSING" behaviour.
    default: Any = dc.field(default_factory=lambda: dc.MISSING)
    optional: bool = False
|
|
|
|
|
|
@dataclass
class ModelDefinition:
    """A parsed model, built from one source dataclass.

    Attributes:
        name: Name of the model.
        fields: The model's fields, as FieldDefinition entries.
        docstring: Docstring text for the model, or None when absent.
    """

    name: str
    fields: List[FieldDefinition]
    docstring: Optional[str] = None
|
|
|
|
|
|
@dataclass
class EnumDefinition:
    """A parsed enum: its name plus its members.

    Attributes:
        name: Name of the enum.
        values: One ``(member name, member value)`` pair per enum member.
    """

    name: str
    values: List[tuple[str, str]]  # (name, value) pairs
|
|
|
|
|
|
@dataclass
class GrpcServiceDefinition:
    """Describes a gRPC service.

    Attributes:
        package: Package the service belongs to.
        name: Name of the service.
        methods: One dict per service method.
    """

    package: str
    name: str
    methods: List[Dict[str, Any]]
|
|
|
|
|
|
class SchemaLoader:
|
|
"""Loads model definitions from Python dataclasses in schema/ folder."""
|
|
|
|
def __init__(self, schema_path: Path):
|
|
self.schema_path = Path(schema_path)
|
|
self.models: List[ModelDefinition] = []
|
|
self.api_models: List[ModelDefinition] = []
|
|
self.enums: List[EnumDefinition] = []
|
|
self.grpc_messages: List[ModelDefinition] = []
|
|
self.grpc_service: Optional[GrpcServiceDefinition] = None
|
|
|
|
def load(self) -> "SchemaLoader":
|
|
"""Load schema definitions from the schema folder."""
|
|
init_path = self.schema_path / "__init__.py"
|
|
|
|
if not init_path.exists():
|
|
raise FileNotFoundError(f"Schema folder must have __init__.py: {init_path}")
|
|
|
|
# Import the schema module
|
|
module = self._import_module(init_path)
|
|
|
|
# Extract DATACLASSES
|
|
dataclasses = getattr(module, "DATACLASSES", [])
|
|
for cls in dataclasses:
|
|
self.models.append(self._parse_dataclass(cls))
|
|
|
|
# Extract API_MODELS (TypeScript-only request/response types)
|
|
api_models = getattr(module, "API_MODELS", [])
|
|
for cls in api_models:
|
|
self.api_models.append(self._parse_dataclass(cls))
|
|
|
|
# Extract ENUMS
|
|
enums = getattr(module, "ENUMS", [])
|
|
for enum_cls in enums:
|
|
self.enums.append(self._parse_enum(enum_cls))
|
|
|
|
# Extract GRPC_MESSAGES (optional)
|
|
grpc_messages = getattr(module, "GRPC_MESSAGES", [])
|
|
for cls in grpc_messages:
|
|
self.grpc_messages.append(self._parse_dataclass(cls))
|
|
|
|
# Extract GRPC_SERVICE (optional)
|
|
grpc_service = getattr(module, "GRPC_SERVICE", None)
|
|
if grpc_service:
|
|
self.grpc_service = GrpcServiceDefinition(
|
|
package=grpc_service.get("package", "service"),
|
|
name=grpc_service.get("name", "Service"),
|
|
methods=grpc_service.get("methods", []),
|
|
)
|
|
|
|
return self
|
|
|
|
def _import_module(self, path: Path):
|
|
"""Import a Python module from a file path."""
|
|
spec = importlib.util.spec_from_file_location("schema", path)
|
|
if spec is None or spec.loader is None:
|
|
raise ImportError(f"Could not load module from {path}")
|
|
|
|
module = importlib.util.module_from_spec(spec)
|
|
sys.modules["schema"] = module
|
|
spec.loader.exec_module(module)
|
|
return module
|
|
|
|
def _parse_dataclass(self, cls: Type) -> ModelDefinition:
|
|
"""Parse a dataclass into a ModelDefinition."""
|
|
hints = get_type_hints(cls)
|
|
fields_info = {f.name: f for f in dc.fields(cls)}
|
|
|
|
fields = []
|
|
for name, type_hint in hints.items():
|
|
if name.startswith("_"):
|
|
continue
|
|
|
|
field_info = fields_info.get(name)
|
|
default = dc.MISSING
|
|
if field_info:
|
|
if field_info.default is not dc.MISSING:
|
|
default = field_info.default
|
|
elif field_info.default_factory is not dc.MISSING:
|
|
default = field_info.default_factory
|
|
|
|
# Check if optional (Union with None)
|
|
optional = self._is_optional(type_hint)
|
|
|
|
fields.append(
|
|
FieldDefinition(
|
|
name=name,
|
|
type_hint=type_hint,
|
|
default=default,
|
|
optional=optional,
|
|
)
|
|
)
|
|
|
|
return ModelDefinition(
|
|
name=cls.__name__,
|
|
fields=fields,
|
|
docstring=cls.__doc__,
|
|
)
|
|
|
|
def _parse_enum(self, enum_cls: Type[Enum]) -> EnumDefinition:
|
|
"""Parse an Enum into an EnumDefinition."""
|
|
values = [(m.name, m.value) for m in enum_cls]
|
|
return EnumDefinition(name=enum_cls.__name__, values=values)
|
|
|
|
def _is_optional(self, type_hint: Any) -> bool:
|
|
"""Check if a type hint is Optional (Union with None)."""
|
|
from typing import Union, get_args, get_origin
|
|
|
|
origin = get_origin(type_hint)
|
|
if origin is Union:
|
|
args = get_args(type_hint)
|
|
return type(None) in args
|
|
return False
|
|
|
|
|
|
def load_schema(schema_path: str | Path) -> SchemaLoader:
    """Build a SchemaLoader for ``schema_path``, load it, and return it."""
    return SchemaLoader(schema_path).load()
|