"""Local Lambda runner — FastAPI wrapper that invokes any
`handler(event, context)` file in /app and reports AWS-equivalent metrics.

Nothing in this file is touched by the lambda function itself; functions stay
verbatim-uploadable to AWS.

Features that are scaffolded now and "light up" later when matching
improvements land in the function:
  - event payload pass-through   (improvement #1: BUCKET/PREFIX from event)
  - structured JSON log capture  (improvement #2: JSON logging to stdout)
  - EMF metric extraction        (improvement #3: CloudWatch EMF embedded metrics)

Until the function emits those, the corresponding output fields are empty.
"""
import asyncio
import importlib.util
import io
import json
import math
import os
import resource
import subprocess
import sys
import time
import traceback
import uuid
from contextlib import redirect_stderr, redirect_stdout
from pathlib import Path

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

FUNCTIONS_DIR = Path(os.environ.get("FUNCTIONS_DIR", "/app/functions"))
SHARED_DIR = Path(os.environ.get("SHARED_DIR", "/app/shared"))
MAX_INVOCATIONS = int(os.environ.get("RUNNER_MAX_INVOCATIONS", "200"))

# Make shared/ importable for any function. Mirrors AWS Lambda Layer behavior
# (layer code is added to PYTHONPATH for all functions that attach it).
if SHARED_DIR.exists():
    _repo_root = str(SHARED_DIR.parent)
    if _repo_root not in sys.path:
        sys.path.insert(0, _repo_root)

app = FastAPI(title="Lambda Local Runner")

_modules: dict = {}  # name -> imported module (cache; presence = warm)
_invocations: list[dict] = []  # newest last; capped at MAX_INVOCATIONS


def _module_key(name: str) -> str:
    """Return the dotted module name under which function `name` is loaded."""
    return f"functions.{name}.handler"


def _safe_name(value: str, what: str) -> str:
    """Validate a single path component taken from the request URL.

    Raises HTTPException(400) when `value` is empty, is `.`/`..`, or contains
    a path separator or NUL byte, so that joining it onto FUNCTIONS_DIR can
    never resolve outside the intended folder. Returns `value` unchanged
    otherwise.
    """
    if (
        not value
        or value in (".", "..")
        or "/" in value
        or "\\" in value
        or "\x00" in value
    ):
        raise HTTPException(status_code=400, detail=f"invalid {what}")
    return value


class LambdaContext:
    """Minimal stand-in for the AWS Lambda context object.

    The function file doesn't use any of this today, but the shape is right
    for when improvements add `context.aws_request_id` to structured logs etc.
    """

    def __init__(self, request_id: str, function_name: str, memory_mb: int, timeout_ms: int):
        self.aws_request_id = request_id
        self.function_name = function_name
        self.function_version = "$LATEST"
        self.invoked_function_arn = f"arn:aws:lambda:local:000000000000:function:{function_name}"
        self.memory_limit_in_mb = memory_mb
        self.log_group_name = f"/aws/lambda/{function_name}"
        self.log_stream_name = f"local/{time.strftime('%Y/%m/%d')}/[$LATEST]{request_id}"
        # Absolute deadline on the monotonic clock, in milliseconds.
        self._deadline_ms = time.monotonic() * 1000 + timeout_ms

    def get_remaining_time_in_millis(self) -> int:
        """Return milliseconds left before the recorded deadline (never negative)."""
        return max(0, int(self._deadline_ms - time.monotonic() * 1000))


class InvokeRequest(BaseModel):
    """Request body for POST /invoke/{name}."""

    event: dict = Field(default_factory=dict)
    # AWS Lambda memory sizes: 128, 256, 512, 1024, 1536, 2048, 3008, 5120, 10240
    memory_mb: int = 128
    # AWS Lambda timeout: 1-900 seconds. Locally we record it but don't kill
    # the handler (matches "function works verbatim" — no signal interruption).
    timeout_ms: int = 30_000


class ScriptRequest(BaseModel):
    """Request body for POST /scripts/{fn_name}/{script_name}."""

    args: list[str] = Field(default_factory=list)


@app.get("/functions")
def list_functions():
    """Scan FUNCTIONS_DIR for subfolders containing a handler.py with
    `def handler(event, context):`.

    Each function lives in its own folder (matches AWS Lambda's
    deployment-package shape). Returns the function names together with the
    sample event files found under each function's `events/` folder.
    """
    funcs: list[dict] = []
    if not FUNCTIONS_DIR.exists():
        return {"functions": [], "functions_dir": str(FUNCTIONS_DIR), "error": "directory not found"}
    for d in sorted(FUNCTIONS_DIR.iterdir()):
        if not d.is_dir() or d.name.startswith("_"):
            continue
        handler = d / "handler.py"
        if not handler.exists():
            continue
        try:
            text = handler.read_text()
        except (OSError, UnicodeDecodeError):
            # Best-effort listing: an unreadable handler is simply not offered.
            continue
        if "def handler(event, context)" not in text and "def handler(event,context)" not in text:
            continue
        # Discover sample events (events/*.json) so the UI can populate a dropdown.
        events_dir = d / "events"
        events = sorted(p.name for p in events_dir.glob("*.json")) if events_dir.is_dir() else []
        funcs.append({"name": d.name, "events": events})
    return {"functions": funcs, "functions_dir": str(FUNCTIONS_DIR)}


@app.get("/functions/{name}/events/{filename}")
def get_event(name: str, filename: str):
    """Serve a sample event file so the UI can preview/select it.

    Raises HTTPException 400 for an invalid name/filename or invalid JSON,
    and 404 when the file does not exist.
    """
    # Both components come straight from the URL; refuse anything that could
    # walk out of FUNCTIONS_DIR/<name>/events (e.g. "..").
    _safe_name(name, "function name")
    _safe_name(filename, "event file name")
    path = FUNCTIONS_DIR / name / "events" / filename
    if not path.exists() or not path.is_file():
        raise HTTPException(status_code=404, detail="event file not found")
    try:
        return json.loads(path.read_text())
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"invalid JSON in {path}: {e}")


@app.post("/invoke/{name}")
def invoke(name: str, req: InvokeRequest):
    """Invoke a handler and return the full invocation record.

    Sync def so FastAPI runs us in a thread — that lets the function call
    `asyncio.run(...)` internally without nested-loop errors (which is how
    the current handler.py works).

    Raises HTTPException 400 for an invalid function name or a module with no
    `handler`, and 404 when the function's handler.py does not exist. Errors
    raised by the function itself are captured in the record's "error" field.
    """
    _safe_name(name, "function name")
    target = FUNCTIONS_DIR / name / "handler.py"
    if not target.exists():
        raise HTTPException(status_code=404, detail=f"{target} not found")

    invocation_id = str(uuid.uuid4())
    cold_start = name not in _modules

    record: dict = {
        "invocation_id": invocation_id,
        "function": name,
        "timestamp": time.time(),
        "event": req.event,
        "result": None,
        "error": None,
        "stdout": "",
        "stderr": "",
        "structured_logs": [],
        "emf_metrics": [],
        "metrics": {
            "cold_start": cold_start,
            "init_duration_ms": None,
            "duration_ms": 0.0,
            "billed_duration_ms": 0,
            "memory_size_mb": req.memory_mb,
            "max_memory_used_mb": 0.0,
        },
    }

    # Cold-start: import the module and time the import. This matches AWS's
    # "Init Duration" — time to load module-level code (imports + module-scope
    # statements). On warm invocations this whole block is skipped.
    if cold_start:
        module_key = _module_key(name)
        spec = importlib.util.spec_from_file_location(module_key, target)
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as the importlib "import a source file
        # directly" recipe does, so code that looks itself up through
        # sys.modules works. /reset removes this entry again.
        sys.modules[module_key] = module
        t0 = time.monotonic()
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            init_duration_ms = (time.monotonic() - t0) * 1000
            # Don't leave a half-initialised module behind.
            sys.modules.pop(module_key, None)
            record["error"] = _format_exception(e)
            record["metrics"]["init_duration_ms"] = round(init_duration_ms, 2)
            return _record(record)
        init_duration_ms = (time.monotonic() - t0) * 1000
        _modules[name] = module
        record["metrics"]["init_duration_ms"] = round(init_duration_ms, 2)

    module = _modules[name]
    if not hasattr(module, "handler"):
        raise HTTPException(status_code=400, detail=f"{name}.py has no handler() function")

    context = LambdaContext(
        request_id=invocation_id,
        function_name=name,
        memory_mb=req.memory_mb,
        timeout_ms=req.timeout_ms,
    )

    stdout_buf = io.StringIO()
    stderr_buf = io.StringIO()
    t_handler = time.monotonic()
    try:
        # NOTE(review): redirect_stdout/redirect_stderr swap the process-wide
        # sys.stdout/sys.stderr, so output from invocations running
        # concurrently in other threads can end up in each other's buffers.
        with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
            result = module.handler(req.event, context)
            # Defensive: if a future async handler ever returns a coroutine
            # (AWS doesn't support that natively, but we might), run it.
            if asyncio.iscoroutine(result):
                result = asyncio.run(result)
        record["result"] = result
    except Exception as e:
        record["error"] = _format_exception(e)
    duration_ms = (time.monotonic() - t_handler) * 1000

    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS;
    # normalise both to megabytes. This is the peak for the whole runner
    # process, not for this invocation alone.
    rusage = resource.getrusage(resource.RUSAGE_SELF)
    maxrss_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024
    max_memory_mb = rusage.ru_maxrss / maxrss_per_mb

    record["stdout"] = stdout_buf.getvalue()
    record["stderr"] = stderr_buf.getvalue()
    record["structured_logs"] = _extract_json_logs(record["stdout"])
    record["emf_metrics"] = _extract_emf_metrics(record["stdout"])
    record["metrics"]["duration_ms"] = round(duration_ms, 2)
    record["metrics"]["billed_duration_ms"] = int(math.ceil(duration_ms))
    record["metrics"]["max_memory_used_mb"] = round(max_memory_mb, 2)

    return _record(record)


@app.get("/invocations")
def list_invocations(limit: int = 50):
    """Index of past invocations, newest first.

    Lightweight summary only — use /invocations/{id} for the full record.
    A `limit` of zero or less yields an empty list.
    """
    # `_invocations[-0:]` is the *whole* list and a negative limit slices from
    # the front, so non-positive limits must be handled explicitly.
    recent = _invocations[-limit:] if limit > 0 else []
    items = [
        {
            "invocation_id": r["invocation_id"],
            "function": r["function"],
            "timestamp": r["timestamp"],
            "cold_start": r["metrics"]["cold_start"],
            "duration_ms": r["metrics"]["duration_ms"],
            "init_duration_ms": r["metrics"]["init_duration_ms"],
            "max_memory_used_mb": r["metrics"]["max_memory_used_mb"],
            "ok": r["error"] is None,
        }
        for r in reversed(recent)
    ]
    return {"invocations": items, "total": len(_invocations)}


@app.get("/invocations/{invocation_id}")
def get_invocation(invocation_id: str):
    """Return the full record for one invocation, or 404 if it is not stored."""
    for r in _invocations:
        if r["invocation_id"] == invocation_id:
            return r
    raise HTTPException(status_code=404, detail="invocation not found")


@app.delete("/invocations")
def clear_invocations():
    """Drop all stored invocation records and report how many were removed."""
    n = len(_invocations)
    _invocations.clear()
    return {"cleared": n}


@app.post("/reset")
def reset_modules():
    """Clear the module cache so the next invocation is cold.

    Useful for A/B-ing cold-start cost without restarting the FastAPI process.
    """
    cleared = list(_modules.keys())
    _modules.clear()
    for name in cleared:
        # Evict under the dotted name the module was loaded as. Popping the
        # bare function name would miss it and could remove an unrelated
        # module that happens to share the name.
        sys.modules.pop(_module_key(name), None)
    return {"cleared": cleared}


@app.get("/functions/{name}/scripts")
def list_scripts(name: str):
    """List support scripts for a function — any .py file that isn't handler.py."""
    _safe_name(name, "function name")
    func_dir = FUNCTIONS_DIR / name
    if not func_dir.is_dir():
        raise HTTPException(status_code=404, detail=f"function {name!r} not found")
    scripts = [
        p.name
        for p in sorted(func_dir.glob("*.py"))
        if p.name not in ("handler.py", "__init__.py")
    ]
    return {"scripts": scripts, "function": name}


@app.post("/scripts/{fn_name}/{script_name}")
def run_script(fn_name: str, script_name: str, req: ScriptRequest):
    """Run a support script from functions/{fn_name}/ with optional args.

    The script runs in a subprocess with the runner's interpreter and a 300 s
    timeout. Returns its return code, captured stdout/stderr and wall-clock
    duration; on timeout the return code is -1.

    Raises HTTPException 400 for an invalid function or script name and 404
    when the script file does not exist.
    """
    # fn_name is joined into the path just like script_name, so it needs the
    # same protection against "..".
    _safe_name(fn_name, "function name")
    if ".." in script_name or "/" in script_name or "\\" in script_name:
        raise HTTPException(status_code=400, detail="invalid script name")
    if not script_name.endswith(".py"):
        raise HTTPException(status_code=400, detail="only .py scripts allowed")
    script_path = FUNCTIONS_DIR / fn_name / script_name
    if not script_path.is_file():
        raise HTTPException(status_code=404, detail=f"{script_path} not found")
    # Argument list (no shell), so req.args are passed through verbatim.
    cmd = [sys.executable, str(script_path)] + list(req.args)
    t0 = time.monotonic()
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        return {"returncode": -1, "stdout": "", "stderr": "timed out after 300 s", "duration_ms": 300_000.0}
    return {
        "returncode": result.returncode,
        "stdout": result.stdout,
        "stderr": result.stderr,
        "duration_ms": round((time.monotonic() - t0) * 1000, 2),
    }


@app.get("/health")
def health():
    """Liveness probe: lists warm modules and the number of stored invocations."""
    return {"ok": True, "loaded_modules": list(_modules.keys()), "invocations": len(_invocations)}


def _format_exception(e: BaseException) -> dict:
    """Convert an exception into a JSON-friendly dict (type, message, traceback)."""
    return {
        "type": type(e).__name__,
        "message": str(e),
        # Format the exception we were handed rather than whatever happens to
        # be the "current" one, so this also works outside an except block.
        "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)),
    }


def _extract_json_logs(stdout_text: str) -> list[dict]:
    """Parse JSON-per-line structured logs out of stdout.

    Fails silently — until the function emits structured logs
    (improvement #2), this returns [].
    """
    logs: list[dict] = []
    for line in stdout_text.splitlines():
        line = line.strip()
        # Cheap pre-filter: only lines that look like a JSON object.
        if not (line.startswith("{") and line.endswith("}")):
            continue
        try:
            logs.append(json.loads(line))
        except json.JSONDecodeError:
            continue
    return logs


def _extract_emf_metrics(stdout_text: str) -> list[dict]:
    """Parse CloudWatch EMF metric records out of stdout.

    EMF format:
        {"_aws": {"CloudWatchMetrics": [...], "Timestamp": ...},
         "{metric_name}": value, ...}

    Fails silently — until the function emits EMF (improvement #3), returns [].
    """
    metrics: list[dict] = []
    for entry in _extract_json_logs(stdout_text):
        aws = entry.get("_aws")
        if isinstance(aws, dict) and "CloudWatchMetrics" in aws:
            metrics.append(entry)
    return metrics


def _record(rec: dict) -> dict:
    """Append `rec` to the invocation history, trim to MAX_INVOCATIONS, return it."""
    _invocations.append(rec)
    if len(_invocations) > MAX_INVOCATIONS:
        del _invocations[: len(_invocations) - MAX_INVOCATIONS]
    return rec