83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
"""Test base — two execution modes from one test suite.

Modes (set via CONTRACT_TEST_MODE env var):

    inprocess (default) — httpx.AsyncClient with ASGI transport, full FastAPI
        stack in-process. Fast, no server needed.

    live — httpx.AsyncClient against CONTRACT_TEST_URL.
        Tests against a running deployment (Kind, EC2, etc).

Usage:

    # In-process (default)
    pytest tests/

    # Against live server
    CONTRACT_TEST_MODE=live CONTRACT_TEST_URL=http://unt.local.ar pytest tests/
"""
|
|
|
|
import asyncio
|
|
import os
|
|
|
|
import httpx
|
|
import pytest_asyncio
|
|
|
|
|
|
def get_mode() -> str:
    """Return the test execution mode.

    The value is read from the ``CONTRACT_TEST_MODE`` environment variable;
    when the variable is not set, ``"inprocess"`` is returned.
    """
    return os.environ.get("CONTRACT_TEST_MODE", "inprocess")
|
|
|
|
|
|
def get_base_url() -> str:
    """Return the base URL used for requests in live mode.

    The value is read from the ``CONTRACT_TEST_URL`` environment variable;
    when the variable is not set, ``"http://localhost:8040"`` is returned.
    """
    return os.environ.get("CONTRACT_TEST_URL", "http://localhost:8040")
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client() -> httpx.AsyncClient:
    """Yield an async HTTP client for API tests.

    When ``get_mode()`` returns ``"live"``, the client sends requests to
    ``get_base_url()`` with a 30 second timeout. For any other mode the
    ``app`` object imported from ``api.main`` is wrapped in an
    ``ASGITransport`` and the client talks to it under the base URL
    ``http://test``.

    The client is used as an async context manager, so it is closed once
    the fixture finishes.
    """
    if get_mode() == "live":
        http_client = httpx.AsyncClient(base_url=get_base_url(), timeout=30.0)
    else:
        # Imported lazily so the application is only loaded when the
        # in-process branch is actually taken.
        from httpx import ASGITransport
        from api.main import app

        http_client = httpx.AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        )

    async with http_client as c:
        yield c
|
|
|
|
|
|
# Alias for agent tests: binds the very same fixture object as `client`
# under a second module-level name.
# NOTE(review): presumably pytest registers this as a separate `agent_client`
# fixture based on the attribute name — confirm against the pytest version in use.
agent_client = client
|
|
|
|
|
|
class ContractHelpers:
    """Reusable assertion and utility methods.

    Every helper is a static method, so it can be called directly on the
    class without creating an instance.
    """

    @staticmethod
    def assert_status(response: httpx.Response, expected: int):
        """Assert that ``response`` has the ``expected`` HTTP status code.

        On mismatch the failure message contains both status codes and the
        first 200 characters of the response body.
        """
        assert response.status_code == expected, (
            f"Expected {expected}, got {response.status_code}: {response.text[:200]}"
        )

    @staticmethod
    def assert_has_fields(data: dict, *fields: str):
        """Assert that every name in ``fields`` is a key of ``data``.

        The failure message names the first missing field and lists the keys
        that are present.
        """
        for f in fields:
            assert f in data, f"Missing field: {f}. Keys: {list(data.keys())}"

    @staticmethod
    def assert_is_list(data, min_length: int = 0):
        """Assert that ``data`` is a ``list`` with at least ``min_length`` items."""
        assert isinstance(data, list), f"Expected list, got {type(data)}"
        assert len(data) >= min_length, f"Expected >= {min_length} items, got {len(data)}"

    @staticmethod
    async def poll_run(
        client: httpx.AsyncClient,
        run_id: str,
        timeout: int = 60,
        interval: float = 1.0,
    ) -> dict:
        """Poll an agent run until completion.

        Repeatedly requests ``/agents/runs/{run_id}`` and inspects the
        ``"status"`` key of the decoded JSON body.

        Args:
            client: HTTP client used to issue the GET requests.
            run_id: Identifier interpolated into the request path.
            timeout: Seconds to keep polling before giving up.
            interval: Seconds to sleep between two polls (default 1.0).

        Returns:
            The decoded JSON body of the first response whose ``"status"``
            is ``"completed"`` or ``"error"``.

        Raises:
            TimeoutError: If no such response is seen before the deadline.
        """
        # Inside a coroutine the running loop is the one whose clock matters;
        # look it up once instead of on every iteration.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            res = await client.get(f"/agents/runs/{run_id}")
            data = res.json()
            if data.get("status") in ("completed", "error"):
                return data
            await asyncio.sleep(interval)
        raise TimeoutError(f"Run {run_id} did not complete in {timeout}s")
|