import atexit
import codecs
import os
import sys
import time
import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from queue import Empty, Queue
from typing import Any
from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesImpl
from colorama import Fore, Style
from junit_xml import TestCase, TestSuite
class AbstractLogger(ABC):
    """Interface implemented by every test-driver log sink.

    ``attributes`` is a free-form string-to-string mapping that a logger may
    attach to a message (``TerminalLogger`` prefixes the message with the
    ``"machine"`` entry, for example). The implementations in this module only
    read ``attributes``, so the shared ``{}`` default is never mutated.
    """

    @abstractmethod
    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Record one log message."""

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager grouping everything logged inside it under ``name``."""

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager that logs ``message`` and nests output inside it."""

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Record an informational message."""

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Record a warning message."""

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Record an error message."""

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        """Record one line of serial-console output from ``machine``."""

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial-console output."""
class JunitXMLLogger(AbstractLogger):
    """Collect output per subtest in memory and write a JUnit XML report.

    Output is bucketed under the current subtest name (``"main"`` outside of
    any subtest). The report file is only written by :meth:`close`, which the
    constructor registers with :mod:`atexit`.
    """

    class TestCaseState:
        """Mutable accumulator for a single JUnit test case."""

        def __init__(self) -> None:
            self.stdout = ""
            self.stderr = ""
            self.failure = False

    def __init__(self, outfile: Path) -> None:
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        atexit.register(self.close)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        self.tests[self.currentSubtest].stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Attribute everything logged inside the block to test case ``name``."""
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name
        try:
            yield
        finally:
            # Restore the enclosing test case even when the body raised, so
            # later output is not attributed to the failed subtest.
            self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Append to the current test case's stderr and mark it as failed."""
        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
        self.tests[self.currentSubtest].failure = True

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return
        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def close(self) -> None:
        """Write all collected test cases to ``self.outfile`` as one suite."""
        # Build the report before opening the file so a failure here does not
        # leave a truncated report behind.
        test_cases = []
        for name, test_case_state in self.tests.items():
            tc = TestCase(
                name,
                stdout=test_case_state.stdout,
                stderr=test_case_state.stderr,
            )
            if test_case_state.failure:
                tc.add_failure_info("test case failed")
            test_cases.append(tc)
        ts = TestSuite("NixOS integration test", test_cases)
        # Write UTF-8 (XML's default encoding) explicitly instead of relying
        # on the locale's default, which may not encode arbitrary test output.
        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(TestSuite.to_xml_string([ts]))
class CompositeLogger(AbstractLogger):
    """Forward every logging call to each logger in ``logger_list``."""

    def __init__(self, logger_list: list[AbstractLogger]) -> None:
        # Stored by reference: add_logger() appends to the caller's list.
        self.logger_list = logger_list

    def add_logger(self, logger: AbstractLogger) -> None:
        self.logger_list.append(logger)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        for logger in self.logger_list:
            logger.log(message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        # Enter every child's subtest context; ExitStack unwinds them in
        # reverse order when the block finishes or raises.
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.subtest(name, attributes))
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.nested(message, attributes))
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        for logger in self.logger_list:
            logger.info(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        for logger in self.logger_list:
            logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Report the error to every logger, then terminate with exit status 1."""
        for logger in self.logger_list:
            logger.error(*args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        for logger in self.logger_list:
            logger.print_serial_logs(enable)

    def log_serial(self, message: str, machine: str) -> None:
        for logger in self.logger_list:
            logger.log_serial(message, machine)
class TerminalLogger(AbstractLogger):
    """Human-readable logger writing (colorama-styled) lines to stderr."""

    def __init__(self) -> None:
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine is given."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        self._eprint(self.maybe_prefix(message, attributes))

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print ``message`` highlighted, then report the block's duration."""
        self._eprint(
            self.maybe_prefix(
                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
            )
        )
        # Monotonic clock: durations must not jump with wall-clock changes.
        tic = time.monotonic()
        yield
        toc = time.monotonic()
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return
        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
class XMLLogger(AbstractLogger):
    """Write the log as an XML document with a ``<logfile>`` root element.

    Serial-console lines are queued by :meth:`log_serial` and written, with
    control characters stripped, whenever the queue is drained: on
    :meth:`log`, around :meth:`nested` blocks, and on :meth:`close`.
    """

    def __init__(self, outfile: str) -> None:
        # Binary handle; XMLGenerator performs the UTF-8 encoding itself.
        # (Equivalent to the former codecs.open(outfile, "wb"), which with no
        # encoding just returned the builtin file object.)
        self.logfile_handle = open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        self.queue: Queue[dict[str, str]] = Queue()
        self._print_serial_logs = True
        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        """Flush queued serial lines, close the root element and the file."""
        # Serial lines queued after the last log() call would otherwise be lost.
        self.drain_log_queue()
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Drop every character in a Unicode "C*" (control/other) category."""
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine is given."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    def log_line(self, message: str, attributes: dict[str, str]) -> None:
        """Emit ``<line attrs...>message</line>``."""
        self.xml.startElement("line", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("line")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        # Write pending serial lines first so the file stays in order.
        # NOTE(review): unlike serial lines, `message` is not sanitised, so a
        # control character here yields invalid XML 1.0; sanitise() would
        # also strip "\n", so it is not a drop-in fix.
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return
        self.enqueue({"msg": message, "machine": machine, "type": "serial"})

    def enqueue(self, item: dict[str, str]) -> None:
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write every queued item as a ``<line>``; remaining keys become attributes."""
        while True:
            try:
                item = self.queue.get_nowait()
            except Empty:
                return
            msg = self.sanitise(item["msg"])
            del item["msg"]
            self.log_line(msg, item)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Wrap the block's output in ``<nest><head>message</head>...</nest>``."""
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("head")

        # Monotonic clock: durations must not jump with wall-clock changes.
        tic = time.monotonic()
        self.drain_log_queue()
        try:
            yield
            self.drain_log_queue()
            toc = time.monotonic()
            self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
        finally:
            # Always close <nest>; otherwise an exception in the body leaves
            # unbalanced tags and close() produces malformed XML.
            self.xml.endElement("nest")