"""
Protobuf Generator

Generates Protocol Buffer definitions from model definitions.
"""

from pathlib import Path
from typing import Any, List, get_type_hints

from ..helpers import get_origin_name, unwrap_optional
from ..loader.schema import GrpcServiceDefinition, ModelDefinition
from ..types import PROTO_RESOLVERS
from .base import BaseGenerator


class ProtobufGenerator(BaseGenerator):
    """Generates Protocol Buffer definition files."""

    def file_extension(self) -> str:
        """Return the file extension for generated files (``.proto``)."""
        return ".proto"

    def generate(self, models, output_path: Path) -> None:
        """Generate protobuf definitions and write them to ``output_path``.

        Accepted inputs:
          - an object exposing ``grpc_messages`` / ``grpc_service``
            (SchemaLoader with gRPC definitions);
          - a ``(messages, service_def, ...)`` tuple; only the first two
            items are used;
          - a list of dataclasses (MPR style).

        The parent directory of ``output_path`` is created if needed and the
        file is written as UTF-8.

        Raises:
            ValueError: if ``models`` is none of the supported input types.
        """
        if hasattr(models, "grpc_messages"):
            # SchemaLoader with gRPC definitions
            content = self._generate_from_loader(models)
        elif isinstance(models, tuple) and len(models) >= 2:
            # (messages, service_def) tuple; extra trailing items are ignored
            content = self._generate_from_definitions(models[0], models[1])
        elif isinstance(models, list):
            # List of dataclasses (MPR style)
            content = self._generate_from_dataclasses(models)
        else:
            raise ValueError(f"Unsupported input type: {type(models)}")

        # Only touch the filesystem once the input has been validated.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content, encoding="utf-8")

    def _generate_from_loader(self, loader) -> str:
        """Generate from a SchemaLoader's ``grpc_messages`` and ``grpc_service``."""
        return self._render_definitions(loader.grpc_messages, loader.grpc_service)

    def _generate_from_definitions(
        self, messages: List[ModelDefinition], service: GrpcServiceDefinition
    ) -> str:
        """Generate from ModelDefinition objects and a service definition."""
        return self._render_definitions(messages, service)

    def _render_definitions(self, messages, service) -> str:
        """Render the header for ``service`` plus one message per definition.

        A falsy ``service`` (e.g. None) falls back to package ``service``,
        service name ``Service`` and no RPC methods.
        """
        if service:
            lines = self._generate_header(
                service.package, service.name, service.methods
            )
        else:
            lines = self._generate_header("service", "Service", [])
        for model_def in messages or ():
            lines.extend(self._generate_message_from_definition(model_def))
            lines.append("")
        return "\n".join(lines)

    def _generate_from_dataclasses(self, dataclasses: List[type]) -> str:
        """Generate from Python dataclasses (MPR style); no service block."""
        lines = self._generate_header("service", "Service", [])
        for cls in dataclasses:
            lines.extend(self._generate_message_from_dataclass(cls))
            lines.append("")
        return "\n".join(lines)

    def _generate_header(
        self, package: str, service_name: str, methods: List[dict]
    ) -> List[str]:
        """Generate the file header and, when ``methods`` is non-empty, the service.

        Each method dict must provide ``name``, ``request`` and ``response``.
        ``request``/``response`` may be classes (their ``__name__`` is used)
        or anything else (rendered with ``str``). Optional truthy
        ``stream_request`` / ``stream_response`` keys mark the corresponding
        side as a stream.
        """
        lines = [
            "// Protocol Buffer Definitions - GENERATED FILE",
            "//",
            "// Do not edit directly. Regenerate using modelgen.",
            "",
            'syntax = "proto3";',
            "",
            f"package {package};",
            "",
        ]
        if methods:
            lines.append(f"service {service_name} {{")
            for m in methods:
                req = self._type_name(m["request"])
                resp = self._type_name(m["response"])
                if m.get("stream_request"):
                    req = f"stream {req}"
                if m.get("stream_response"):
                    resp = f"stream {resp}"
                lines.append(f"  rpc {m['name']}({req}) returns ({resp});")
            lines.extend(["}", ""])
        return lines

    @staticmethod
    def _type_name(value: Any) -> str:
        """Return ``value.__name__`` if present, else ``str(value)``."""
        return value.__name__ if hasattr(value, "__name__") else str(value)

    def _generate_message_from_definition(
        self, model_def: ModelDefinition
    ) -> List[str]:
        """Generate a proto message from a ModelDefinition."""
        fields = [(field.name, field.type_hint) for field in model_def.fields or ()]
        return self._render_message(model_def.name, fields)

    def _generate_message_from_dataclass(self, cls: type) -> List[str]:
        """Generate a proto message from a dataclass's resolved type hints."""
        hints = get_type_hints(cls)
        return self._render_message(cls.__name__, list(hints.items()))

    def _render_message(
        self, message_name: str, fields: List[tuple]
    ) -> List[str]:
        """Render ``message <name> { ... }`` from ``(field_name, type_hint)`` pairs.

        Field numbers are assigned sequentially starting at 1 in the given
        order. An empty field list produces a ``// Empty`` placeholder.
        """
        lines = [f"message {message_name} {{"]
        if not fields:
            lines.append("  // Empty")
        for number, (field_name, type_hint) in enumerate(fields, 1):
            proto_type, optional = self._resolve_type(type_hint)
            # Defensive: never emit `optional` on repeated/map fields, even if
            # a subclass overrides _resolve_type.
            use_optional = optional and self._supports_optional(proto_type)
            prefix = "optional " if use_optional else ""
            lines.append(f"  {prefix}{proto_type} {field_name} = {number};")
        lines.append("}")
        return lines

    @staticmethod
    def _supports_optional(proto_type: str) -> bool:
        """Return False for field types that proto3 forbids labelling ``optional``."""
        # Repeated fields and map fields cannot carry the `optional` label.
        return not proto_type.startswith(("repeated", "map<"))

    def _resolve_type(self, type_hint: Any) -> tuple[str, bool]:
        """Resolve a Python type to a proto type.

        Returns ``(proto_type, is_optional)``. Resolvers are looked up in
        PROTO_RESOLVERS by origin name first, then by the unwrapped type
        itself. Unknown types fall back to ``"string"``.
        """
        base, optional = unwrap_optional(type_hint)
        origin = get_origin_name(base)

        resolver = PROTO_RESOLVERS.get(origin) or PROTO_RESOLVERS.get(base)
        if resolver:
            proto_type = resolver(base)
            return proto_type, optional and self._supports_optional(proto_type)

        return "string", optional