""" Core logic for test discovery and execution. """ import unittest import time import threading import traceback import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Optional from enum import Enum class TestStatus(str, Enum): PENDING = "pending" RUNNING = "running" PASSED = "passed" FAILED = "failed" ERROR = "error" SKIPPED = "skipped" @dataclass class TestInfo: """Information about a discovered test.""" id: str name: str module: str class_name: str method_name: str doc: Optional[str] = None @dataclass class TestResult: """Result of a single test execution.""" test_id: str name: str status: TestStatus duration: float = 0.0 error_message: Optional[str] = None traceback: Optional[str] = None artifacts: list[dict] = field(default_factory=list) # List of artifact metadata @dataclass class RunStatus: """Status of a test run.""" run_id: str status: str # "running", "completed", "failed" total: int = 0 completed: int = 0 passed: int = 0 failed: int = 0 errors: int = 0 skipped: int = 0 results: list[TestResult] = field(default_factory=list) started_at: Optional[float] = None finished_at: Optional[float] = None current_test: Optional[str] = None # Global storage for run statuses _runs: dict[str, RunStatus] = {} _runs_lock = threading.Lock() def discover_tests() -> list[TestInfo]: """Discover all tests in the tests directory.""" tests_dir = Path(__file__).parent / "tests" # top_level_dir must be contracts_http's parent (tools/) so that # relative imports like "from ...base" resolve to contracts_http.base top_level = Path(__file__).parent.parent loader = unittest.TestLoader() # Discover tests suite = loader.discover(str(tests_dir), pattern="test_*.py", top_level_dir=str(top_level)) tests = [] def extract_tests(suite_or_case): if isinstance(suite_or_case, unittest.TestSuite): for item in suite_or_case: extract_tests(item) elif isinstance(suite_or_case, unittest.TestCase): test_method = getattr(suite_or_case, suite_or_case._testMethodName, None) doc = test_method.__doc__ if test_method else None # Build module path relative to tests/ module_parts = suite_or_case.__class__.__module__.split(".") # Remove 'contracts_http.tests' prefix if present if len(module_parts) > 2 and module_parts[-3] == "tests": module_name = ".".join(module_parts[-2:]) else: module_name = suite_or_case.__class__.__module__ test_id = f"{module_name}.{suite_or_case.__class__.__name__}.{suite_or_case._testMethodName}" tests.append(TestInfo( id=test_id, name=suite_or_case._testMethodName, module=module_name, class_name=suite_or_case.__class__.__name__, method_name=suite_or_case._testMethodName, doc=doc.strip() if doc else None, )) extract_tests(suite) return tests def get_tests_tree() -> dict: """Get tests organized as a tree structure for the UI.""" tests = discover_tests() tree = {} for test in tests: # Parse module to get folder structure parts = test.module.split(".") folder = parts[0] if parts else "root" if folder not in tree: tree[folder] = {"modules": {}, "test_count": 0} module_name = parts[-1] if len(parts) > 1 else test.module if module_name not in tree[folder]["modules"]: tree[folder]["modules"][module_name] = {"classes": {}, "test_count": 0} if test.class_name not in tree[folder]["modules"][module_name]["classes"]: tree[folder]["modules"][module_name]["classes"][test.class_name] = {"tests": [], "test_count": 0} tree[folder]["modules"][module_name]["classes"][test.class_name]["tests"].append({ "id": test.id, "name": test.method_name, "doc": test.doc, }) tree[folder]["modules"][module_name]["classes"][test.class_name]["test_count"] += 1 tree[folder]["modules"][module_name]["test_count"] += 1 tree[folder]["test_count"] += 1 return tree class ResultCollector(unittest.TestResult): """Custom test result collector.""" def __init__(self, run_status: RunStatus): super().__init__() self.run_status = run_status self._test_start_times: dict[str, float] = {} def _get_test_id(self, test: unittest.TestCase) -> str: module_parts = test.__class__.__module__.split(".") if len(module_parts) > 2 and module_parts[-3] == "tests": module_name = ".".join(module_parts[-2:]) else: module_name = test.__class__.__module__ return f"{module_name}.{test.__class__.__name__}.{test._testMethodName}" def startTest(self, test): super().startTest(test) test_id = self._get_test_id(test) self._test_start_times[test_id] = time.time() with _runs_lock: self.run_status.current_test = test_id def stopTest(self, test): super().stopTest(test) with _runs_lock: self.run_status.current_test = None def addSuccess(self, test): super().addSuccess(test) test_id = self._get_test_id(test) duration = time.time() - self._test_start_times.get(test_id, time.time()) result = TestResult( test_id=test_id, name=test._testMethodName, status=TestStatus.PASSED, duration=duration, ) with _runs_lock: self.run_status.results.append(result) self.run_status.completed += 1 self.run_status.passed += 1 def addFailure(self, test, err): super().addFailure(test, err) test_id = self._get_test_id(test) duration = time.time() - self._test_start_times.get(test_id, time.time()) result = TestResult( test_id=test_id, name=test._testMethodName, status=TestStatus.FAILED, duration=duration, error_message=str(err[1]), traceback="".join(traceback.format_exception(*err)), ) with _runs_lock: self.run_status.results.append(result) self.run_status.completed += 1 self.run_status.failed += 1 def addError(self, test, err): super().addError(test, err) test_id = self._get_test_id(test) duration = time.time() - self._test_start_times.get(test_id, time.time()) result = TestResult( test_id=test_id, name=test._testMethodName, status=TestStatus.ERROR, duration=duration, error_message=str(err[1]), traceback="".join(traceback.format_exception(*err)), ) with _runs_lock: self.run_status.results.append(result) self.run_status.completed += 1 self.run_status.errors += 1 def addSkip(self, test, reason): super().addSkip(test, reason) test_id = self._get_test_id(test) duration = time.time() - self._test_start_times.get(test_id, time.time()) result = TestResult( test_id=test_id, name=test._testMethodName, status=TestStatus.SKIPPED, duration=duration, error_message=reason, ) with _runs_lock: self.run_status.results.append(result) self.run_status.completed += 1 self.run_status.skipped += 1 def _run_tests_thread(run_id: str, test_ids: Optional[list[str]] = None): """Run tests in a background thread.""" tests_dir = Path(__file__).parent / "tests" top_level = Path(__file__).parent.parent loader = unittest.TestLoader() # Discover all tests suite = loader.discover(str(tests_dir), pattern="test_*.py", top_level_dir=str(top_level)) # Filter to selected tests if specified if test_ids: filtered_suite = unittest.TestSuite() def filter_tests(suite_or_case): if isinstance(suite_or_case, unittest.TestSuite): for item in suite_or_case: filter_tests(item) elif isinstance(suite_or_case, unittest.TestCase): module_parts = suite_or_case.__class__.__module__.split(".") if len(module_parts) > 2 and module_parts[-3] == "tests": module_name = ".".join(module_parts[-2:]) else: module_name = suite_or_case.__class__.__module__ test_id = f"{module_name}.{suite_or_case.__class__.__name__}.{suite_or_case._testMethodName}" # Check if this test matches any of the requested IDs for requested_id in test_ids: if test_id == requested_id or test_id.startswith(requested_id + ".") or requested_id in test_id: filtered_suite.addTest(suite_or_case) break filter_tests(suite) suite = filtered_suite # Count total tests total = suite.countTestCases() with _runs_lock: _runs[run_id].total = total _runs[run_id].started_at = time.time() # Run tests with our collector collector = ResultCollector(_runs[run_id]) try: suite.run(collector) except Exception as e: with _runs_lock: _runs[run_id].status = "failed" with _runs_lock: _runs[run_id].status = "completed" _runs[run_id].finished_at = time.time() def start_test_run(test_ids: Optional[list[str]] = None) -> str: """Start a test run in the background. Returns run_id.""" run_id = str(uuid.uuid4())[:8] run_status = RunStatus( run_id=run_id, status="running", ) with _runs_lock: _runs[run_id] = run_status # Start background thread thread = threading.Thread(target=_run_tests_thread, args=(run_id, test_ids)) thread.daemon = True thread.start() return run_id def get_run_status(run_id: str) -> Optional[RunStatus]: """Get the status of a test run.""" with _runs_lock: return _runs.get(run_id) def list_runs() -> list[dict]: """List all test runs.""" with _runs_lock: return [ { "run_id": run.run_id, "status": run.status, "total": run.total, "completed": run.completed, "passed": run.passed, "failed": run.failed, } for run in _runs.values() ]