import os
import re
import signal
import tempfile
import threading
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
# Unique marker object, distinguishable from every other value by identity.
# NOTE(review): nothing in the visible part of this module references it.
SENTINEL = object()
def get_tmp_dir() -> Path:
    """Return the temporary directory defined by TMPDIR, TEMP, TMP or CWD.

    The location is resolved by ``tempfile.gettempdir()``; see
    https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    Returns:
        The temporary directory as a ``Path``.

    Raises:
        NotADirectoryError: the resolved path is not a directory.
        PermissionError: the resolved directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    # Create the directory if it is missing; an existing directory is left
    # untouched (including its permissions).
    tmp_dir.mkdir(mode=0o700, exist_ok=True)
    if not tmp_dir.is_dir():
        raise NotADirectoryError(
            f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
        )
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir
def pythonize_name(name: str) -> str:
    """Turn ``name`` into a string usable as a Python identifier.

    Every character that is not an ASCII letter, digit or underscore is
    replaced by ``_``; a leading digit is replaced as well, since an
    identifier may not start with one.

    The character classes spell out ``A-Za-z`` on purpose: the range ``A-z``
    would also cover ``[``, ``\\``, ``]``, ``^`` and the backtick, which sit
    between ``Z`` and ``a`` in ASCII, and let them through unreplaced.
    """
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests.

    The driver owns the VLANs and machines it creates and is meant to be used
    as a context manager: leaving the ``with`` block cancels the global
    timeout timer, releases every machine and stops every VLAN.
    """

    tests: str
    out_dir: Path
    vlans: list[VLan]
    machines: list[Machine]
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
        start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
        """Create the VLANs and machine handles for a test run.

        Args:
            start_scripts: one start script per machine; each is wrapped in a
                ``NixStartScript``, which also provides the machine name.
            vlans: VLAN numbers to start; duplicates are ignored.
            tests: source code of the test script run by ``test_script()``.
            out_dir: output directory handed to every machine.
            logger: logger used by the driver, its VLANs and its machines.
            keep_vm_state: forwarded to every ``Machine``.
            global_timeout: seconds after which ``terminate_test()`` fires
                once ``run_tests()`` has armed the timer.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # Only created here; the timer is started by run_tests().
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            # Deduplicate so that every VLAN number is set up only once.
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in set(vlans)]

        # Must exist before any machine can invoke check_polling_conditions.
        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=script,
                keep_vm_state=keep_vm_state,
                name=script.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for script in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the timeout timer, release all machines and stop all VLANs.

        Cleanup is best effort: a failure for one machine or VLAN is logged
        and does not prevent the remaining ones from being cleaned up.
        """
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")

            for vlan in self.vlans:
                try:
                    vlan.stop()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of vlan{vlan.nr}: {e}")

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name

        This is a plain generator; ``test_symbols()`` wraps it with
        ``contextmanager`` before exposing it to test scripts. An exception
        raised inside the subtest is logged and then re-raised.
        """
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        The namespace holds the driver helpers, every machine under its
        pythonized name and every VLAN as ``vlan<nr>``. The names are also
        printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n "
            + ", ".join(m.name for m in self.machines)
            + ",\n "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            # Keep the VLAN names and the general symbols on separate lines.
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # NOTE(review): executes self.tests verbatim; it must only ever
            # hold the test author's own script, never untrusted input.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def terminate_test(self) -> None:
        """Abort the test run; callback of the global timeout timer."""
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            # Release the machines first, best effort, so that a failure here
            # can never prevent the SIGTERM below from being sent.
            # NOTE(review): assumes SIGTERM ends the process without running
            # __exit__ (no handler is installed in this module) -- confirm.
            for machine in self.machines:
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str,
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start command.

        Args:
            start_command: start script, wrapped in a ``NixStartScript``.
            name: machine name; defaults to the name provided by the script.
            keep_vm_state: forwarded to ``Machine``.

        Returns:
            The new ``Machine``. It is not added to ``self.machines``.
        """
        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Let every active polling condition raise if it needs to."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Callable | None = None,
        seconds_interval: float = 2.0,
        description: str | None = None,
    ) -> Callable[[Callable], AbstractContextManager] | AbstractContextManager:
        """Wrap a function in a polling condition.

        Usable as a bare decorator (``fun_`` given: returns a ``Poll``
        instance) or as a configured decorator (``fun_`` omitted: returns the
        ``Poll`` class, to be applied to the function afterwards).

        Args:
            fun_: the function to poll.
            seconds_interval: forwarded to ``PollingCondition``.
            description: forwarded to ``PollingCondition``.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, *_: Any) -> None:
                # Conditions must be left in reverse order of entry.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Retry the condition until it holds or ``timeout`` expires."""

                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)