fixes and modelgen insert
This commit is contained in:
168
tools/modelgen/generator/protobuf.py
Normal file
168
tools/modelgen/generator/protobuf.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Protobuf Generator
|
||||
|
||||
Generates Protocol Buffer definitions from model definitions.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, List, get_type_hints
|
||||
|
||||
from ..helpers import get_origin_name, unwrap_optional
|
||||
from ..loader.schema import GrpcServiceDefinition, ModelDefinition
|
||||
from ..types import PROTO_RESOLVERS
|
||||
from .base import BaseGenerator
|
||||
|
||||
|
||||
class ProtobufGenerator(BaseGenerator):
|
||||
"""Generates Protocol Buffer definition files."""
|
||||
|
||||
def file_extension(self) -> str:
|
||||
return ".proto"
|
||||
|
||||
def generate(self, models, output_path: Path) -> None:
|
||||
"""Generate protobuf definitions to output_path."""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Handle different input types
|
||||
if hasattr(models, "grpc_messages"):
|
||||
# SchemaLoader with gRPC definitions
|
||||
content = self._generate_from_loader(models)
|
||||
elif isinstance(models, tuple) and len(models) >= 3:
|
||||
# (messages, service_def) tuple
|
||||
content = self._generate_from_definitions(models[0], models[1])
|
||||
elif isinstance(models, list):
|
||||
# List of dataclasses (MPR style)
|
||||
content = self._generate_from_dataclasses(models)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(models)}")
|
||||
|
||||
output_path.write_text(content)
|
||||
|
||||
def _generate_from_loader(self, loader) -> str:
|
||||
"""Generate from SchemaLoader."""
|
||||
messages = loader.grpc_messages
|
||||
service = loader.grpc_service
|
||||
|
||||
lines = self._generate_header(
|
||||
service.package if service else "service",
|
||||
service.name if service else "Service",
|
||||
service.methods if service else [],
|
||||
)
|
||||
|
||||
for model_def in messages:
|
||||
lines.extend(self._generate_message_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_from_definitions(
|
||||
self, messages: List[ModelDefinition], service: GrpcServiceDefinition
|
||||
) -> str:
|
||||
"""Generate from ModelDefinition objects."""
|
||||
lines = self._generate_header(service.package, service.name, service.methods)
|
||||
|
||||
for model_def in messages:
|
||||
lines.extend(self._generate_message_from_definition(model_def))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
|
||||
"""Generate from Python dataclasses (MPR style)."""
|
||||
lines = self._generate_header("service", "Service", [])
|
||||
|
||||
for cls in dataclasses:
|
||||
lines.extend(self._generate_message_from_dataclass(cls))
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_header(
|
||||
self, package: str, service_name: str, methods: List[dict]
|
||||
) -> List[str]:
|
||||
"""Generate file header with service definition."""
|
||||
lines = [
|
||||
"// Protocol Buffer Definitions - GENERATED FILE",
|
||||
"//",
|
||||
"// Do not edit directly. Regenerate using modelgen.",
|
||||
"",
|
||||
'syntax = "proto3";',
|
||||
"",
|
||||
f"package {package};",
|
||||
"",
|
||||
]
|
||||
|
||||
if methods:
|
||||
lines.append(f"service {service_name} {{")
|
||||
for m in methods:
|
||||
req = (
|
||||
m["request"].__name__
|
||||
if hasattr(m["request"], "__name__")
|
||||
else str(m["request"])
|
||||
)
|
||||
resp = (
|
||||
m["response"].__name__
|
||||
if hasattr(m["response"], "__name__")
|
||||
else str(m["response"])
|
||||
)
|
||||
returns = f"stream {resp}" if m.get("stream_response") else resp
|
||||
lines.append(f" rpc {m['name']}({req}) returns ({returns});")
|
||||
lines.extend(["}", ""])
|
||||
|
||||
return lines
|
||||
|
||||
def _generate_message_from_definition(
|
||||
self, model_def: ModelDefinition
|
||||
) -> List[str]:
|
||||
"""Generate proto message from ModelDefinition."""
|
||||
lines = [f"message {model_def.name} {{"]
|
||||
|
||||
if not model_def.fields:
|
||||
lines.append(" // Empty")
|
||||
else:
|
||||
for i, field in enumerate(model_def.fields, 1):
|
||||
proto_type, optional = self._resolve_type(field.type_hint)
|
||||
prefix = (
|
||||
"optional "
|
||||
if optional and not proto_type.startswith("repeated")
|
||||
else ""
|
||||
)
|
||||
lines.append(f" {prefix}{proto_type} {field.name} = {i};")
|
||||
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _generate_message_from_dataclass(self, cls: type) -> List[str]:
|
||||
"""Generate proto message from a dataclass."""
|
||||
lines = [f"message {cls.__name__} {{"]
|
||||
|
||||
hints = get_type_hints(cls)
|
||||
if not hints:
|
||||
lines.append(" // Empty")
|
||||
else:
|
||||
for i, (name, type_hint) in enumerate(hints.items(), 1):
|
||||
proto_type, optional = self._resolve_type(type_hint)
|
||||
prefix = (
|
||||
"optional "
|
||||
if optional and not proto_type.startswith("repeated")
|
||||
else ""
|
||||
)
|
||||
lines.append(f" {prefix}{proto_type} {name} = {i};")
|
||||
|
||||
lines.append("}")
|
||||
return lines
|
||||
|
||||
def _resolve_type(self, type_hint: Any) -> tuple[str, bool]:
|
||||
"""Resolve Python type to proto type. Returns (type, is_optional)."""
|
||||
base, optional = unwrap_optional(type_hint)
|
||||
origin = get_origin_name(base)
|
||||
|
||||
# Look up resolver
|
||||
resolver = PROTO_RESOLVERS.get(origin) or PROTO_RESOLVERS.get(base)
|
||||
|
||||
if resolver:
|
||||
result = resolver(base)
|
||||
is_repeated = result.startswith("repeated")
|
||||
return result, optional and not is_repeated
|
||||
|
||||
return "string", optional
|
||||
Reference in New Issue
Block a user