Files
soleprint/station/tools/tester/core.py
2025-12-24 05:38:37 -03:00

343 lines
11 KiB
Python

"""
Core logic for test discovery and execution.
"""
import unittest
import time
import threading
import traceback
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from enum import Enum
class TestStatus(str, Enum):
    """State of a single test.

    Mixing in ``str`` makes each member compare equal to its lowercase string
    value (e.g. ``TestStatus.PASSED == "passed"``), so members serialize as
    plain strings.
    """

    PENDING = "pending"
    RUNNING = "running"
    PASSED = "passed"
    FAILED = "failed"
    ERROR = "error"
    SKIPPED = "skipped"
@dataclass
class TestInfo:
    """Description of one discovered unittest test method.

    discover_tests builds ``id`` as ``"<module>.<class_name>.<method_name>"``.
    """

    id: str  # unique key used by the UI and for test selection
    name: str  # display name (discover_tests uses the method name)
    module: str  # dotted module path, shortened relative to tests/ when possible
    class_name: str  # TestCase subclass name
    method_name: str  # test method name
    doc: Optional[str] = None  # stripped docstring of the test method, if any
@dataclass
class TestResult:
    """Outcome record for one executed test (not unittest.TestResult).

    Instances are appended to ``RunStatus.results`` while a run progresses.
    """

    test_id: str  # same "<module>.<class>.<method>" id format as TestInfo.id
    name: str  # short name of the test (its method name)
    status: TestStatus  # outcome of the test
    duration: float = 0.0  # wall-clock seconds between start and outcome
    error_message: Optional[str] = None  # exception text or skip reason
    traceback: Optional[str] = None  # formatted traceback for failures/errors
    artifacts: list[dict] = field(default_factory=list)  # artifact metadata dicts
@dataclass
class RunStatus:
    """Progress and collected results of one test run.

    The runner thread updates these fields while API readers poll them.
    """

    run_id: str
    status: str  # "running", "completed", "failed"
    total: int = 0  # number of test cases in the (possibly filtered) suite
    completed: int = 0  # tests that have reported an outcome
    passed: int = 0
    failed: int = 0
    errors: int = 0
    skipped: int = 0
    results: list[TestResult] = field(default_factory=list)  # one entry per reported outcome
    started_at: Optional[float] = None  # time.time() once the suite is built
    finished_at: Optional[float] = None  # time.time() when the run ended
    current_test: Optional[str] = None  # id of the test executing right now, if any
# Guards _runs and the RunStatus objects stored in it: those objects are
# written by background runner threads while API calls read them.
_runs_lock = threading.Lock()
# Process-wide registry of test runs: run_id -> RunStatus, filled by start_test_run.
_runs: dict[str, RunStatus] = dict()
def discover_tests() -> list[TestInfo]:
    """Describe every test found under the ``tests`` directory next to this file.

    Uses ``unittest.TestLoader.discover`` with pattern ``test_*.py`` and
    returns one ``TestInfo`` per ``unittest.TestCase`` in the order the loader
    yields them.
    """
    here = Path(__file__).parent
    # top_level_dir is this package's parent directory so that relative
    # imports inside the test modules (e.g. "from ...base") resolve.
    found = unittest.TestLoader().discover(
        str(here / "tests"), pattern="test_*.py", top_level_dir=str(here.parent)
    )

    collected: list[TestInfo] = []
    pending = [found]
    while pending:
        node = pending.pop()
        if isinstance(node, unittest.TestSuite):
            # Push children reversed so they are popped in their original order.
            pending.extend(reversed(list(node)))
            continue
        if not isinstance(node, unittest.TestCase):
            continue

        cls = node.__class__
        method_name = node._testMethodName
        method = getattr(node, method_name, None)
        docstring = method.__doc__ if method else None

        # "<...>.tests.<folder>.<module>" is shortened to "<folder>.<module>";
        # any other layout keeps the full dotted module path.
        dotted = cls.__module__.split(".")
        if len(dotted) > 2 and dotted[-3] == "tests":
            module_name = ".".join(dotted[-2:])
        else:
            module_name = cls.__module__

        collected.append(
            TestInfo(
                id=f"{module_name}.{cls.__name__}.{method_name}",
                name=method_name,
                module=module_name,
                class_name=cls.__name__,
                method_name=method_name,
                doc=docstring.strip() if docstring else None,
            )
        )
    return collected
def get_tests_tree(tests: Optional[list["TestInfo"]] = None) -> dict:
    """Group tests into the nested folder -> module -> class tree used by the UI.

    Args:
        tests: Tests to group. When omitted, ``discover_tests()`` is called;
            pass an already-discovered list to avoid a second discovery pass.

    Returns:
        ``{folder: {"modules": {module: {"classes": {class_name: {"tests": [...],
        "test_count": int}}, "test_count": int}}, "test_count": int}}``, where
        each entry in "tests" is ``{"id", "name", "doc"}``. The folder is the
        first dotted part of ``TestInfo.module`` and the module is its last
        part (both equal the whole name for a single-part module).
    """
    if tests is None:
        tests = discover_tests()

    tree: dict = {}
    for test in tests:
        # str.split always returns at least one element, so both indexes are safe.
        parts = test.module.split(".")
        folder = tree.setdefault(parts[0], {"modules": {}, "test_count": 0})
        module = folder["modules"].setdefault(parts[-1], {"classes": {}, "test_count": 0})
        cls = module["classes"].setdefault(test.class_name, {"tests": [], "test_count": 0})

        cls["tests"].append({"id": test.id, "name": test.method_name, "doc": test.doc})
        cls["test_count"] += 1
        module["test_count"] += 1
        folder["test_count"] += 1
    return tree
class ResultCollector(unittest.TestResult):
    """unittest result object that streams each outcome into a shared RunStatus.

    Every reported outcome is appended to ``run_status.results`` as a
    ``TestResult``. ``completed`` and the per-status counters are updated
    under ``_runs_lock`` because other threads read the RunStatus while the
    suite runs.

    Status mapping:
      * success, expected failure                -> PASSED  ("passed")
      * failure, unexpected success              -> FAILED  ("failed")
      * error, including class/module fixture
        errors reported without a TestCase       -> ERROR   ("errors")
      * skip                                     -> SKIPPED ("skipped")

    unittest reports failing subtests only through ``addSubTest``; no
    addSuccess/addFailure follows them. They are therefore folded into one
    FAILED/ERROR result for their test in ``stopTest``. When a test reports
    several outcomes (e.g. a failure and then a tearDown error), each one is
    kept in ``results``. The test is counted only once, by its first outcome.
    """

    def __init__(self, run_status: RunStatus):
        super().__init__()
        self.run_status = run_status
        # test id -> time.time() at startTest; removed again in stopTest.
        self._test_start_times: dict[str, float] = {}
        # test id -> [(subtest, exc_info), ...] for subtests that failed.
        self._subtest_failures: dict[str, list] = {}
        # Ids already counted toward ``completed`` and the status counters.
        self._counted: set[str] = set()

    def _get_test_id(self, test) -> str:
        """Return "<module>.<Class>.<method>", shortened relative to tests/.

        Class/module fixture failures (setUpClass, setUpModule, ...) are
        reported with unittest's ``_ErrorHolder``. It is not a TestCase and
        has no ``_testMethodName``, so its ``id()`` description is used instead.
        """
        method_name = getattr(test, "_testMethodName", None)
        if method_name is None:
            return test.id()
        module_parts = test.__class__.__module__.split(".")
        if len(module_parts) > 2 and module_parts[-3] == "tests":
            module_name = ".".join(module_parts[-2:])
        else:
            module_name = test.__class__.__module__
        return f"{module_name}.{test.__class__.__name__}.{method_name}"

    @staticmethod
    def _get_test_name(test) -> str:
        """Return the test method name, or the holder's description for fixture errors."""
        return getattr(test, "_testMethodName", None) or str(test)

    @staticmethod
    def _format_err(err) -> str:
        """Render an exc_info triple as a traceback string."""
        return "".join(traceback.format_exception(*err))

    def _record(self, test, status, counter, error_message=None, tb_text=None):
        """Append a TestResult for *test*; count it if it is the test's first outcome.

        Args:
            test: TestCase (or fixture-error holder) the outcome belongs to.
            status: TestStatus stored on the result.
            counter: RunStatus counter to increment: "passed", "failed",
                "errors" or "skipped".
            error_message: Exception text or skip reason, if any.
            tb_text: Formatted traceback, if any.
        """
        test_id = self._get_test_id(test)
        started = self._test_start_times.get(test_id)
        duration = time.time() - started if started is not None else 0.0
        result = TestResult(
            test_id=test_id,
            name=self._get_test_name(test),
            status=status,
            duration=duration,
            error_message=error_message,
            traceback=tb_text,
        )
        first_outcome = test_id not in self._counted
        self._counted.add(test_id)
        with _runs_lock:
            self.run_status.results.append(result)
            if first_outcome:
                self.run_status.completed += 1
                setattr(self.run_status, counter, getattr(self.run_status, counter) + 1)

    def startTest(self, test):
        super().startTest(test)
        test_id = self._get_test_id(test)
        self._test_start_times[test_id] = time.time()
        with _runs_lock:
            self.run_status.current_test = test_id

    def stopTest(self, test):
        super().stopTest(test)
        test_id = self._get_test_id(test)
        failures = self._subtest_failures.pop(test_id, None)
        if failures:
            # unittest classifies a subtest exception as a failure when it
            # derives from the test's failureException; anything else is an error.
            is_error = any(
                not issubclass(err[0], test.failureException) for _, err in failures
            )
            self._record(
                test,
                TestStatus.ERROR if is_error else TestStatus.FAILED,
                "errors" if is_error else "failed",
                error_message="\n".join(f"{subtest}: {err[1]}" for subtest, err in failures),
                tb_text="\n".join(self._format_err(err) for _, err in failures),
            )
        self._counted.discard(test_id)
        self._test_start_times.pop(test_id, None)
        with _runs_lock:
            self.run_status.current_test = None

    def addSuccess(self, test):
        super().addSuccess(test)
        self._record(test, TestStatus.PASSED, "passed")

    def addFailure(self, test, err):
        super().addFailure(test, err)
        self._record(
            test, TestStatus.FAILED, "failed",
            error_message=str(err[1]), tb_text=self._format_err(err),
        )

    def addError(self, test, err):
        super().addError(test, err)
        self._record(
            test, TestStatus.ERROR, "errors",
            error_message=str(err[1]), tb_text=self._format_err(err),
        )

    def addSkip(self, test, reason):
        super().addSkip(test, reason)
        self._record(test, TestStatus.SKIPPED, "skipped", error_message=reason)

    def addSubTest(self, test, subtest, err):
        super().addSubTest(test, subtest, err)
        if err is not None:
            self._subtest_failures.setdefault(self._get_test_id(test), []).append((subtest, err))

    def addExpectedFailure(self, test, err):
        super().addExpectedFailure(test, err)
        # unittest treats an expected failure as a successful outcome.
        self._record(
            test, TestStatus.PASSED, "passed",
            error_message=f"expected failure: {err[1]}", tb_text=self._format_err(err),
        )

    def addUnexpectedSuccess(self, test):
        super().addUnexpectedSuccess(test)
        # unittest's wasSuccessful() reports False when unexpected successes exist.
        self._record(test, TestStatus.FAILED, "failed", error_message="unexpected success")
def _run_tests_thread(run_id: str, test_ids: Optional[list[str]] = None):
    """Discover, optionally filter, and run the tests for *run_id* (thread target).

    This function updates ``_runs[run_id]`` in place:
    - ``total`` and ``started_at`` are set once the suite is built.
    - Per-test progress is recorded through ``ResultCollector``.
    - At the end, ``finished_at`` is set and ``status`` becomes "completed",
      or "failed" if discovery, filtering or execution raised.

    Args:
        run_id: Key of a RunStatus already registered in ``_runs``.
        test_ids: Optional selectors. A test runs when any selector is a
            substring of its id. ``None`` or an empty list runs every test.
    """
    with _runs_lock:
        run_status = _runs[run_id]

    final_status = "failed"
    try:
        tests_dir = Path(__file__).parent / "tests"
        top_level = Path(__file__).parent.parent
        suite = unittest.TestLoader().discover(
            str(tests_dir), pattern="test_*.py", top_level_dir=str(top_level)
        )

        if test_ids:
            selected = unittest.TestSuite()

            def _select(node):
                if isinstance(node, unittest.TestSuite):
                    for child in node:
                        _select(child)
                elif isinstance(node, unittest.TestCase):
                    module_parts = node.__class__.__module__.split(".")
                    if len(module_parts) > 2 and module_parts[-3] == "tests":
                        module_name = ".".join(module_parts[-2:])
                    else:
                        module_name = node.__class__.__module__
                    test_id = f"{module_name}.{node.__class__.__name__}.{node._testMethodName}"
                    # Substring matching also covers exact ids and
                    # folder/module/class prefixes.
                    # NOTE(review): it can over-select, e.g. "test_a" also matches "test_ab".
                    if any(wanted in test_id for wanted in test_ids):
                        selected.addTest(node)

            _select(suite)
            suite = selected

        total = suite.countTestCases()
        with _runs_lock:
            run_status.total = total
            run_status.started_at = time.time()

        suite.run(ResultCollector(run_status))
        final_status = "completed"
    except Exception:
        # RunStatus has no field for the cause, so print it to stderr rather
        # than losing it; the run is finalized as "failed" below.
        traceback.print_exc()
    finally:
        with _runs_lock:
            run_status.status = final_status
            run_status.finished_at = time.time()
def start_test_run(test_ids: Optional[list[str]] = None) -> str:
    """Register a new run and execute it on a background daemon thread.

    Args:
        test_ids: Optional test-id selectors forwarded to ``_run_tests_thread``.

    Returns:
        An 8-character lowercase hex run id, unique among the runs in ``_runs``.
    """
    with _runs_lock:
        # Ids are truncated UUIDs (32 bits), so collisions are possible;
        # regenerate instead of overwriting an existing run's status.
        while True:
            run_id = uuid.uuid4().hex[:8]
            if run_id not in _runs:
                break
        _runs[run_id] = RunStatus(run_id=run_id, status="running")

    thread = threading.Thread(
        target=_run_tests_thread,
        args=(run_id, test_ids),
        name=f"test-run-{run_id}",
        daemon=True,
    )
    thread.start()
    return run_id
def get_run_status(run_id: str) -> Optional[RunStatus]:
    """Look up a test run by id; returns ``None`` for an unknown ``run_id``.

    The returned RunStatus is the same object the runner thread keeps
    mutating. The lock here only guards the dictionary lookup.
    """
    with _runs_lock:
        found = _runs.get(run_id)
    return found
def list_runs() -> list[dict]:
    """Summarize every registered test run, in registration order.

    Each summary dict carries the run's id, status and progress counters.
    This includes ``errors`` and ``skipped``, so a run that only errored is
    not mistaken for a clean one.
    """
    with _runs_lock:
        return [
            {
                "run_id": run.run_id,
                "status": run.status,
                "total": run.total,
                "completed": run.completed,
                "passed": run.passed,
                "failed": run.failed,
                "errors": run.errors,
                "skipped": run.skipped,
            }
            for run in _runs.values()
        ]