remove duplicated code
This commit is contained in:
169
modelgen/loader/schema.py
Normal file
169
modelgen/loader/schema.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Schema Loader
|
||||
|
||||
Loads Python dataclasses from a schema/ folder.
|
||||
Expects the folder to have an __init__.py that exports:
|
||||
- DATACLASSES: List of dataclass types to generate
|
||||
- ENUMS: List of Enum types to include
|
||||
- GRPC_MESSAGES: (optional) List of gRPC message types
|
||||
- GRPC_SERVICE: (optional) gRPC service definition dict
|
||||
"""
|
||||
|
||||
import dataclasses as dc
|
||||
import importlib.util
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldDefinition:
|
||||
"""Represents a model field."""
|
||||
|
||||
name: str
|
||||
type_hint: Any
|
||||
default: Any = dc.MISSING
|
||||
optional: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDefinition:
|
||||
"""Represents a model/dataclass."""
|
||||
|
||||
name: str
|
||||
fields: List[FieldDefinition]
|
||||
docstring: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumDefinition:
|
||||
"""Represents an enum."""
|
||||
|
||||
name: str
|
||||
values: List[tuple[str, str]] # (name, value) pairs
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrpcServiceDefinition:
|
||||
"""Represents a gRPC service."""
|
||||
|
||||
package: str
|
||||
name: str
|
||||
methods: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SchemaLoader:
|
||||
"""Loads model definitions from Python dataclasses in schema/ folder."""
|
||||
|
||||
def __init__(self, schema_path: Path):
|
||||
self.schema_path = Path(schema_path)
|
||||
self.models: List[ModelDefinition] = []
|
||||
self.enums: List[EnumDefinition] = []
|
||||
self.grpc_messages: List[ModelDefinition] = []
|
||||
self.grpc_service: Optional[GrpcServiceDefinition] = None
|
||||
|
||||
def load(self) -> "SchemaLoader":
|
||||
"""Load schema definitions from the schema folder."""
|
||||
init_path = self.schema_path / "__init__.py"
|
||||
|
||||
if not init_path.exists():
|
||||
raise FileNotFoundError(f"Schema folder must have __init__.py: {init_path}")
|
||||
|
||||
# Import the schema module
|
||||
module = self._import_module(init_path)
|
||||
|
||||
# Extract DATACLASSES
|
||||
dataclasses = getattr(module, "DATACLASSES", [])
|
||||
for cls in dataclasses:
|
||||
self.models.append(self._parse_dataclass(cls))
|
||||
|
||||
# Extract ENUMS
|
||||
enums = getattr(module, "ENUMS", [])
|
||||
for enum_cls in enums:
|
||||
self.enums.append(self._parse_enum(enum_cls))
|
||||
|
||||
# Extract GRPC_MESSAGES (optional)
|
||||
grpc_messages = getattr(module, "GRPC_MESSAGES", [])
|
||||
for cls in grpc_messages:
|
||||
self.grpc_messages.append(self._parse_dataclass(cls))
|
||||
|
||||
# Extract GRPC_SERVICE (optional)
|
||||
grpc_service = getattr(module, "GRPC_SERVICE", None)
|
||||
if grpc_service:
|
||||
self.grpc_service = GrpcServiceDefinition(
|
||||
package=grpc_service.get("package", "service"),
|
||||
name=grpc_service.get("name", "Service"),
|
||||
methods=grpc_service.get("methods", []),
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _import_module(self, path: Path):
|
||||
"""Import a Python module from a file path."""
|
||||
spec = importlib.util.spec_from_file_location("schema", path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Could not load module from {path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["schema"] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
def _parse_dataclass(self, cls: Type) -> ModelDefinition:
|
||||
"""Parse a dataclass into a ModelDefinition."""
|
||||
hints = get_type_hints(cls)
|
||||
fields_info = {f.name: f for f in dc.fields(cls)}
|
||||
|
||||
fields = []
|
||||
for name, type_hint in hints.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
field_info = fields_info.get(name)
|
||||
default = dc.MISSING
|
||||
if field_info:
|
||||
if field_info.default is not dc.MISSING:
|
||||
default = field_info.default
|
||||
elif field_info.default_factory is not dc.MISSING:
|
||||
default = field_info.default_factory
|
||||
|
||||
# Check if optional (Union with None)
|
||||
optional = self._is_optional(type_hint)
|
||||
|
||||
fields.append(
|
||||
FieldDefinition(
|
||||
name=name,
|
||||
type_hint=type_hint,
|
||||
default=default,
|
||||
optional=optional,
|
||||
)
|
||||
)
|
||||
|
||||
return ModelDefinition(
|
||||
name=cls.__name__,
|
||||
fields=fields,
|
||||
docstring=cls.__doc__,
|
||||
)
|
||||
|
||||
def _parse_enum(self, enum_cls: Type[Enum]) -> EnumDefinition:
|
||||
"""Parse an Enum into an EnumDefinition."""
|
||||
values = [(m.name, m.value) for m in enum_cls]
|
||||
return EnumDefinition(name=enum_cls.__name__, values=values)
|
||||
|
||||
def _is_optional(self, type_hint: Any) -> bool:
|
||||
"""Check if a type hint is Optional (Union with None)."""
|
||||
from typing import Union, get_args, get_origin
|
||||
|
||||
origin = get_origin(type_hint)
|
||||
if origin is Union:
|
||||
args = get_args(type_hint)
|
||||
return type(None) in args
|
||||
return False
|
||||
|
||||
|
||||
def load_schema(schema_path: str | Path) -> SchemaLoader:
|
||||
"""Load schema definitions from folder."""
|
||||
loader = SchemaLoader(schema_path)
|
||||
return loader.load()
|
||||
Reference in New Issue
Block a user