diff --git a/bazel/labgrid/runner/runner.py b/bazel/labgrid/runner/runner.py index 3f9754c7299eb4d5d7818c6e267b4508feb79d27..f111cb3719f69938d829119b10b97727e7f7d288 100644 --- a/bazel/labgrid/runner/runner.py +++ b/bazel/labgrid/runner/runner.py @@ -3,9 +3,9 @@ from __future__ import annotations import logging from contextlib import contextmanager from dataclasses import dataclass +from io import StringIO from os import environ, linesep from pathlib import Path, PurePath -from sys import stderr, stdout from typing import ( Iterator, Mapping, @@ -13,6 +13,7 @@ from typing import ( from threading import get_native_id from uuid import getnode from subprocess import CalledProcessError +from sys import stdout, stderr from bazel.labgrid.strategy import State, transition from labgrid import Environment @@ -112,18 +113,19 @@ class Runner: if not download.optional and code == 0: raise e - def run(self, cmd: str, env: Mapping[str, str] = {}, check=False) -> int: + def run( + self, + cmd: str, + env: Mapping[str, str] = {}, + stdout=stdout, + stderr=stderr, + check=False, + ) -> int: """Run a command on the device with given environment variables.""" cmd = ( f"cd {self._exec_root} && {self._tools.env(self._default_env() | env, cmd)}" ) - out, err, code = self._run(cmd, check=check) - for line in out: - stdout.write(f"{line}{linesep}") - for line in err: - stderr.write(f"{line}{linesep}") - - return (code, linesep.join(out)) + return self._run(cmd, stdout=stdout, stderr=stderr, check=check) def __enter__(self): # Get a unique temporary directory, avoiding concurrent instances @@ -158,20 +160,30 @@ class Runner: return base / path return path - def _run(self, cmd, check=True): + def _run(self, cmd, stdout=None, stderr=None, check=True): # We should let Bazel kill the process with the configured timeout if running a test timeout = None if environ.get("BAZEL_TEST") == "1" else 30 out, err, code = self._shell.run(cmd, timeout=timeout) + if check and code: - raise CalledProcessError( - cmd=cmd, returncode=code, output=stdout, stderr=err - ) + raise CalledProcessError(cmd=cmd, returncode=code, output=out, stderr=err) + + if stdout: + for line in out: + stdout.write(f"{line}{linesep}") + if stderr: + for line in err: + stderr.write(f"{line}{linesep}") - return out, err, code + return code def _read_path(self, cmd): - out, _, _ = self._run(cmd) - return PurePath(linesep.join(out).rstrip()) + try: + output = StringIO() + self._run(cmd, stdout=output) + return PurePath(output.getvalue().rstrip()) + finally: + output.close() @staticmethod def _default_env(): diff --git a/examples/custom-runners/archive-transfer/run.py b/examples/custom-runners/archive-transfer/run.py index 890cebf05c0046bc159afbc9d416174798e02548..e5c8565a671e3824c34f5ea1431478ee8a956884 100644 --- a/examples/custom-runners/archive-transfer/run.py +++ b/examples/custom-runners/archive-transfer/run.py @@ -11,12 +11,14 @@ def arguments(): parser = argparse.ArgumentParser() parser.add_argument("program", type=Path) parser.add_argument("arguments", nargs=argparse.REMAINDER) - parser.add_argument("--put", nargs="+", dest="puts", type=Path) + parser.add_argument("--put", nargs="+", dest="puts", type=Path, default=[]) parser.add_argument("--out", type=Path) return parser.parse_args() -def main(): +def main() -> int: + exit_code = -1 + args = arguments() runfiles = Runfiles.Create() @@ -34,14 +36,13 @@ def main(): r.put(uploads) r.run(f"./{unzip.remote} {archive.remote}", check=True) - code, stdout = r.run( - f"./{args.program.name} {''.join(args.arguments)}", check=True - ) - with args.out.open("w") as f: - print(stdout, file=f) + exit_code = r.run( + f"./{args.program.name} {''.join(args.arguments)}", stdout=f, check=True + ) - return code + return exit_code -main() +if __name__ == "__main__": + exit(main()) diff --git a/labgrid/run/run.py b/labgrid/run/run.py index 41c2817963f15534f8f4e4154543daa508ace6f8..c5b242f0a4921fd8d3fd8b7f8f803527f17bcfd8 100644 --- a/labgrid/run/run.py +++ b/labgrid/run/run.py @@ -24,7 +24,7 @@ def run( cmd = join((f"./{program_remote}", *arguments)) r.put(uploads) - code, _ = r.run(cmd, env) + code = r.run(cmd, env) r.get(downloads, code) return code