"""
Protobuf Generator

Generates Protocol Buffer definitions from model definitions.
"""
|
from pathlib import Path
from typing import Any, List, get_type_hints

from ..helpers import get_origin_name, unwrap_optional
from ..loader.schema import GrpcServiceDefinition, ModelDefinition
from ..types import PROTO_RESOLVERS
from .base import BaseGenerator


class ProtobufGenerator(BaseGenerator):
    """Generates Protocol Buffer definition files (proto3 syntax)."""

    def file_extension(self) -> str:
        """Return the extension used for generated files."""
        return ".proto"

    def generate(self, models, output_path: Path) -> None:
        """Generate protobuf definitions to output_path.

        Args:
            models: One of:
                * an object exposing ``grpc_messages`` (SchemaLoader style),
                * a ``(messages, service_def)`` tuple,
                * a list of dataclasses (MPR style).
            output_path: Destination file; parent directories are created.

        Raises:
            ValueError: If ``models`` matches none of the supported shapes.
        """
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Handle different input types
        if hasattr(models, "grpc_messages"):
            # SchemaLoader with gRPC definitions
            content = self._generate_from_loader(models)
        elif isinstance(models, tuple) and len(models) >= 2:
            # (messages, service_def) tuple.
            # BUG FIX: the original check required len(models) >= 3, which
            # rejected the documented 2-tuple form even though only the
            # first two elements are ever used.
            content = self._generate_from_definitions(models[0], models[1])
        elif isinstance(models, list):
            # List of dataclasses (MPR style)
            content = self._generate_from_dataclasses(models)
        else:
            raise ValueError(f"Unsupported input type: {type(models)}")

        output_path.write_text(content)

    def _generate_from_loader(self, loader) -> str:
        """Generate proto source from a SchemaLoader-like object.

        The loader must expose ``grpc_messages`` (iterable of
        ModelDefinition) and ``grpc_service`` (service definition or None).
        """
        messages = loader.grpc_messages
        service = loader.grpc_service

        # Fall back to generic package/service names when no service exists.
        lines = self._generate_header(
            service.package if service else "service",
            service.name if service else "Service",
            service.methods if service else [],
        )

        for model_def in messages:
            lines.extend(self._generate_message_from_definition(model_def))
            lines.append("")

        return "\n".join(lines)

    def _generate_from_definitions(
        self, messages: List[ModelDefinition], service: GrpcServiceDefinition
    ) -> str:
        """Generate proto source from ModelDefinition objects plus a service."""
        lines = self._generate_header(service.package, service.name, service.methods)

        for model_def in messages:
            lines.extend(self._generate_message_from_definition(model_def))
            lines.append("")

        return "\n".join(lines)

    def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
        """Generate proto source from Python dataclasses (MPR style)."""
        # No service information is available for raw dataclasses.
        lines = self._generate_header("service", "Service", [])

        for cls in dataclasses:
            lines.extend(self._generate_message_from_dataclass(cls))
            lines.append("")

        return "\n".join(lines)

    def _generate_header(
        self, package: str, service_name: str, methods: List[dict]
    ) -> List[str]:
        """Generate the file header, optionally followed by a service block.

        Each method dict needs ``name``, ``request`` and ``response`` keys;
        an optional truthy ``stream_response`` marks a server-streaming RPC.
        """
        lines = [
            "// Protocol Buffer Definitions - GENERATED FILE",
            "//",
            "// Do not edit directly. Regenerate using modelgen.",
            "",
            'syntax = "proto3";',
            "",
            f"package {package};",
            "",
        ]

        if methods:
            lines.append(f"service {service_name} {{")
            for m in methods:
                # request/response may be classes or plain type-name strings.
                req = (
                    m["request"].__name__
                    if hasattr(m["request"], "__name__")
                    else str(m["request"])
                )
                resp = (
                    m["response"].__name__
                    if hasattr(m["response"], "__name__")
                    else str(m["response"])
                )
                returns = f"stream {resp}" if m.get("stream_response") else resp
                lines.append(f" rpc {m['name']}({req}) returns ({returns});")
            lines.extend(["}", ""])

        return lines

    def _field_line(self, name: str, type_hint: Any, index: int) -> str:
        """Render one message field line with its assigned field number.

        Shared by both message generators; previously this logic was
        duplicated in each of them.
        """
        proto_type, optional = self._resolve_type(type_hint)
        # proto3 forbids "optional repeated"; repeated already encodes
        # absence, so the prefix is suppressed for repeated fields.
        prefix = (
            "optional "
            if optional and not proto_type.startswith("repeated")
            else ""
        )
        return f" {prefix}{proto_type} {name} = {index};"

    def _generate_message_from_definition(
        self, model_def: ModelDefinition
    ) -> List[str]:
        """Generate a proto message block from a ModelDefinition."""
        lines = [f"message {model_def.name} {{"]

        if not model_def.fields:
            lines.append(" // Empty")
        else:
            # Field numbers are assigned sequentially starting at 1.
            for i, field in enumerate(model_def.fields, 1):
                lines.append(self._field_line(field.name, field.type_hint, i))

        lines.append("}")
        return lines

    def _generate_message_from_dataclass(self, cls: type) -> List[str]:
        """Generate a proto message block from a dataclass's type hints."""
        lines = [f"message {cls.__name__} {{"]

        hints = get_type_hints(cls)
        if not hints:
            lines.append(" // Empty")
        else:
            # Field numbers are assigned sequentially starting at 1.
            for i, (name, type_hint) in enumerate(hints.items(), 1):
                lines.append(self._field_line(name, type_hint, i))

        lines.append("}")
        return lines

    def _resolve_type(self, type_hint: Any) -> tuple[str, bool]:
        """Resolve Python type to proto type. Returns (type, is_optional).

        Unknown types fall back to "string". A repeated result is never
        reported optional (proto3 forbids "optional repeated").
        """
        base, optional = unwrap_optional(type_hint)
        origin = get_origin_name(base)

        # Look up resolver by origin name first, then by the bare type.
        resolver = PROTO_RESOLVERS.get(origin) or PROTO_RESOLVERS.get(base)

        if resolver:
            result = resolver(base)
            is_repeated = result.startswith("repeated")
            return result, optional and not is_repeated

        return "string", optional