diff --git a/doc/conf.py b/doc/conf.py
index 21f36231a4fbb21fc455131ca206dd4b6fdda577..db256dfbc4eb4522698b3eff50b3e8ee787c020d 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -12,12 +12,14 @@
 # All configuration values have a default; values that are commented out
 # serve to show the default.
 
+import itertools
 import logging
 import os
 import re
 import subprocess
 import sys
 import inspect
+import tempfile
 
 from docutils import nodes
 from sphinx.util.docfields import TypedField
@@ -398,4 +400,39 @@ def setup(app):
     app.connect('autodoc-process-docstring', autodoc_process_test_method)
     app.connect('autodoc-process-docstring', autodoc_process_analysis_events)
 
-# vim :set tabstop=4 shiftwidth=4 textwidth=80 expandtab
+
+def get_test_id_list():
+    import lisa.tests.kernel as test_package
+    with tempfile.NamedTemporaryFile('wt') as conf:
+        # Create a dummy target configuration, so that exekall can build all
+        # expressions starting from that
+        conf.write('target-conf:\n')
+        conf.flush()
+
+        rst_list = []
+        for path in test_package.__path__:
+            rst_list.extend(subprocess.check_output((
+                'exekall', 'run', path,
+                '--conf', conf.name,
+                '--rst-list'
+            ), stderr=subprocess.STDOUT).decode('utf-8').splitlines()
+            )
+    return rst_list
+
+def create_test_list_file(path):
+    try:
+        content = '\n'.join(
+            '* {}'.format(test_id)
+            for test_id in get_test_id_list()
+        )
+    except FileNotFoundError:
+        content = 'Please install exekall in order to generate the list of tests.'
+        print('WARNING: could not generate the list of tests without exekall', file=sys.stderr)
+
+    with open(path, 'wt') as f:
+        f.write(content + '\n')
+
+create_test_list_file('test_list.rst')
+
+
+# vim :set tabstop=4 shiftwidth=4 textwidth=80 expandtab:
diff --git a/doc/kernel_tests.rst b/doc/kernel_tests.rst
index ed02409388f5e6c0f94be1d7725166d2770e7f4a..5f39cfe583268e94efcdb8f67604cf179826c46c 100644
--- a/doc/kernel_tests.rst
+++ b/doc/kernel_tests.rst
@@ -21,6 +21,16 @@
 In our case, the data usually consists of `Ftrace `_ traces
 that we then postprocess using :mod:`trappy`.
 
+Available tests
+===============
+
+The following tests are available. They can be used in two ways:
+  * direct execution, using the ``lisa-test`` command (``LISA shell``) and ``exekall``
+  * the individual classes/methods they are composed of can be reused in custom
+    scripts/Jupyter notebooks (see ipynb/tests/synthetics_example.ipynb)
+
+.. include:: test_list.rst
+
 Writing tests
 =============
 
@@ -59,8 +69,11 @@
 execution of `rt-app `_ workloads. It is very
 useful for scheduler-related tests, as it makes it easy to create tasks with a
 pre-determined utilization.
 
+API
+===
+
 Base API
-========
+++++++++
 
 .. automodule:: lisa.tests.kernel.test_bundle
     :members:
@@ -68,9 +81,6 @@
 
 .. TODO:: Make those imports more generic
 
-Implemented tests
-=================
-
 Scheduler tests
 +++++++++++++++
 
@@ -107,3 +117,4 @@ Cpufreq tests
 ..
automodule:: lisa.tests.kernel.cpufreq.sanity :members: + diff --git a/lisa/conf.py b/lisa/conf.py index 0084b6e7d46f7b6d6bbdb55e4135874f026cfd2c..1419159ff81adbb37ea4722c37109c99fd8d8d9c 100644 --- a/lisa/conf.py +++ b/lisa/conf.py @@ -438,7 +438,7 @@ class MultiSrcConfABC(Serializable, abc.ABC, metaclass=MultiSrcConfMeta): mapping = cls._from_path(path, fmt='yaml') assert isinstance(mapping, Mapping) - data = mapping[toplevel_key] + data = mapping[toplevel_key] or {} # "unwrap" an extra layer of toplevel key, to play well with !include if len(data) == 1 and toplevel_key in data.keys(): data = data[toplevel_key] diff --git a/lisa/exekall_customize.py b/lisa/exekall_customize.py index 31f01be440004635526b87b371fde74c4a0c0dd6..ea2de88185df6aaa7dc400af3ad6aa1b22634c89 100644 --- a/lisa/exekall_customize.py +++ b/lisa/exekall_customize.py @@ -34,17 +34,20 @@ from lisa.conf import MultiSrcConf from lisa.tests.kernel.test_bundle import TestBundle, Result, ResultBundle, CannotCreateError from lisa.tests.kernel.scheduler.load_tracking import FreqInvarianceItem -from exekall.utils import info, get_name, get_mro -from exekall.engine import ExprData, Consumer, PrebuiltOperator, NoValue, StorageDB +from exekall.utils import info, get_name, get_mro, NoValue +from exekall.engine import ExprData, Consumer, PrebuiltOperator, ValueDB from exekall.customization import AdaptorBase -class ExekallArtifactPath(ArtifactPath): +class NonReusable: + pass + +class ExekallArtifactPath(ArtifactPath, NonReusable): @classmethod def from_expr_data(cls, data:ExprData, consumer:Consumer) -> 'ExekallArtifactPath': """ Factory used when running under `exekall` """ - artifact_dir = Path(data['testcase_artifact_dir']).resolve() + artifact_dir = Path(data['expr_artifact_dir']).resolve() consumer_name = get_name(consumer) # Find a non-used directory @@ -69,13 +72,11 @@ class LISAAdaptor(AdaptorBase): name = 'LISA' def get_non_reusable_type_set(self): - return { - ExekallArtifactPath, - } + return {NonReusable} - def get_prebuilt_list(self): + def get_prebuilt_set(self): non_reusable_type_set = self.get_non_reusable_type_set() - op_list = [] + op_set = set() # Try to build as many configurations instances from all the files we # are given @@ -107,37 +108,36 @@ class LISAAdaptor(AdaptorBase): for conf_src, conf_path in conf_and_path_list[1:]: conf.add_src(conf_path, conf_src, fallback=True) - op_list.append( - PrebuiltOperator(conf_cls, [conf], + op_set.add(PrebuiltOperator( + conf_cls, [conf], non_reusable_type_set=non_reusable_type_set )) # Inject serialized objects as root operators for path in self.args.inject: obj = Serializable.from_path(path) - op_list.append( - PrebuiltOperator(type(obj), [obj], + op_set.add(PrebuiltOperator(type(obj), [obj], non_reusable_type_set=non_reusable_type_set )) - return op_list + return op_set - def get_hidden_callable_set(self, op_map): - hidden_callable_set = set() - for produced, op_set in op_map.items(): - if issubclass(produced, HideExekallID): - hidden_callable_set.update(op.callable_ for op in op_set) - - self.hidden_callable_set = hidden_callable_set - return hidden_callable_set + def get_hidden_op_set(self, op_set): + hidden_op_set = { + op for op in op_set + if issubclass(op.value_type, HideExekallID) + } + self.hidden_op_set = hidden_op_set + return hidden_op_set @staticmethod def register_cli_param(parser): parser.add_argument('--conf', action='append', default=[], - help="Configuration file") + help="LISA configuration file. 
If multiple configurations of a given type are found, they are merged (last one can override keys in previous ones)") parser.add_argument('--inject', action='append', + metavar='SERIALIZED_OBJECT_PATH', default=[], help="Serialized object to inject when building expressions") @@ -147,14 +147,15 @@ class LISAAdaptor(AdaptorBase): @classmethod def load_db(cls, db_path, *args, **kwargs): + db = super().load_db(db_path, *args, **kwargs) + # This will relocate ArtifactPath instances to the new absolute path of # the results folder, in case it has been moved to another place artifact_dir = Path(db_path).parent.resolve() - db = StorageDB.from_path(db_path, *args, **kwargs) # Relocate ArtifactPath embeded in objects so they will always # contain an absolute path that adapts to the local filesystem - for serial in db.obj_store.get_all(): + for serial in db.get_all(): val = serial.value try: dct = val.__dict__ @@ -162,33 +163,45 @@ class LISAAdaptor(AdaptorBase): continue for attr, attr_val in dct.items(): if isinstance(attr_val, ArtifactPath): - setattr(val, attr, - attr_val.with_root(artifact_dir) - ) + new_path = attr_val.with_root(artifact_dir) + # Only update paths to existing files, otherwise assume it + # was pointing outside the artifact_dir and therefore + # should not be fixed up + if os.path.exists(new_path): + setattr(val, attr, new_path) return db def finalize_expr(self, expr): - testcase_artifact_dir = expr.data['testcase_artifact_dir'] + expr_artifact_dir = expr.data['expr_artifact_dir'] artifact_dir = expr.data['artifact_dir'] for expr_val in expr.get_all_vals(): - self._finalize_expr_val(expr_val, artifact_dir, testcase_artifact_dir) + self._finalize_expr_val(expr_val, artifact_dir, expr_artifact_dir) - def _finalize_expr_val(self, expr_val, artifact_dir, testcase_artifact_dir): + def _finalize_expr_val(self, expr_val, artifact_dir, expr_artifact_dir): val = expr_val.value + def needs_rewriting(val): + # Only rewrite ArtifactPath path values + if not isinstance(val, ArtifactPath): + return False + # And only if they are a subfolder of artifact_dir. Otherwise, they + # are something pointing outside of the artifact area, which we + # cannot handle. 
+ return artifact_dir.resolve() in Path(val).resolve().parents + # Add symlinks to artifact folders for ExprValue that were used in the # ExprValue graph, but were initially computed for another Expression - if isinstance(val, ArtifactPath): + if needs_rewriting(val): val = Path(val) - is_subfolder = (testcase_artifact_dir.resolve() in val.resolve().parents) + is_subfolder = (expr_artifact_dir.resolve() in val.resolve().parents) # The folder is reachable from our ExprValue, but is not a - # subfolder of the testcase_artifact_dir, so we want to get a + # subfolder of the expr_artifact_dir, so we want to get a # symlink to it if not is_subfolder: # We get the name of the callable callable_folder = val.parts[-2] - folder = testcase_artifact_dir/callable_folder + folder = expr_artifact_dir/callable_folder # We build a relative path back in the hierarchy to the root of # all artifacts @@ -208,8 +221,8 @@ class LISAAdaptor(AdaptorBase): symlink.symlink_to(target, target_is_directory=True) - for param, param_expr_val in expr_val.param_expr_val_map.items(): - self._finalize_expr_val(param_expr_val, artifact_dir, testcase_artifact_dir) + for param, param_expr_val in expr_val.param_map.items(): + self._finalize_expr_val(param_expr_val, artifact_dir, expr_artifact_dir) @classmethod def get_tags(cls, value): @@ -230,20 +243,25 @@ class LISAAdaptor(AdaptorBase): return tags - def process_results(self, result_map): - super().process_results(result_map) + def get_summary(self, result_map): + summary = super().get_summary(result_map) # The goal is to implement something that is roughly compatible with: # https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd # This way, Jenkins should be able to read it, and other tools as well xunit_path = self.args.artifact_dir.joinpath('xunit.xml') - et_root = self.create_xunit(result_map, self.hidden_callable_set) + hidden_callable_set = { + op.callable_ + for op in self.hidden_op_set + } + et_root = self._create_xunit(result_map, hidden_callable_set) et_tree = ET.ElementTree(et_root) info('Writing xUnit file at: ' + str(xunit_path)) et_tree.write(str(xunit_path)) + return summary - def create_xunit(self, result_map, hidden_callable_set): + def _create_xunit(self, result_map, hidden_callable_set): et_testsuites = ET.Element('testsuites') testcase_list = list(result_map.keys()) @@ -270,11 +288,11 @@ class LISAAdaptor(AdaptorBase): # Get the set of UUIDs of all TestBundle instances that were # involved in the testcase. - def bundle_predicate(expr_val, param): + def bundle_predicate(expr_val): return issubclass(expr_val.expr.op.value_type, TestBundle) bundle_uuid_set = { - expr_val.value_uuid - for expr_val in expr_val.get_parent_expr_vals(bundle_predicate) + expr_val.uuid + for expr_val in expr_val.get_by_predicate(bundle_predicate) } bundle_uuid_set.discard(None) @@ -292,7 +310,7 @@ class LISAAdaptor(AdaptorBase): )) testsuite_counters['tests'] += 1 - for failed_expr_val in expr_val.get_failed_expr_vals(): + for failed_expr_val in expr_val.get_excep(): excep = failed_expr_val.excep # When one critical object cannot be created, we assume # the test was skipped. 
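For reference, the xUnit output targeted here boils down to nested ``testsuites``/``testsuite``/``testcase`` elements, with result sub-elements such as ``skipped`` or ``failure``. A minimal sketch of that shape (test IDs and skip reasons are made up for illustration), using only the standard ``xml.etree.ElementTree`` API::

    import xml.etree.ElementTree as ET

    # Hypothetical outcomes: one passed test, one skipped because a critical
    # object could not be created (the CannotCreateError case above)
    results = [
        ('EASBehaviour:test_slack', None),
        ('FreqInvarianceItem:test_task_util_avg', 'CannotCreateError'),
    ]

    et_testsuites = ET.Element('testsuites')
    et_testsuite = ET.SubElement(
        et_testsuites, 'testsuite',
        name='lisa-test', tests=str(len(results)),
    )

    for test_id, skip_reason in results:
        et_testcase = ET.SubElement(et_testsuite, 'testcase', name=test_id)
        if skip_reason is not None:
            # The schema represents tests that could not run as a "skipped"
            # sub-element of their testcase
            ET.SubElement(et_testcase, 'skipped', type=skip_reason)

    print(ET.tostring(et_testsuites).decode())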
@@ -307,7 +325,7 @@ class LISAAdaptor(AdaptorBase): msg = ''.join(traceback.format_exception(type(excep), excep, excep.__traceback__)) type_ = type(excep) - append_result_tag(et_testcase, result, type_, short_msg, msg) + _append_result_tag(et_testcase, result, type_, short_msg, msg) value = expr_val.value if isinstance(value, ResultBundle): @@ -316,7 +334,7 @@ class LISAAdaptor(AdaptorBase): msg = str(value) type_ = type(value) - append_result_tag(et_testcase, result, type_, short_msg, msg) + _append_result_tag(et_testcase, result, type_, short_msg, msg) if value.result is Result.FAILED: testsuite_counters['failures'] += 1 @@ -326,8 +344,10 @@ class LISAAdaptor(AdaptorBase): return et_testsuites +# Expose it as a module-level name +load_db = LISAAdaptor.load_db -def append_result_tag(et_testcase, result, type_, short_msg, msg): +def _append_result_tag(et_testcase, result, type_, short_msg, msg): et_result = ET.SubElement(et_testcase, result, dict( type=get_name(type_, full_qual=True), type_bases=','.join( diff --git a/lisa/tests/kernel/scheduler/eas_behaviour.py b/lisa/tests/kernel/scheduler/eas_behaviour.py index b8969839dbc3861956392281f15e6d1e5ec0170d..e13304c58d4d97b4867ca614f6605379281431c1 100644 --- a/lisa/tests/kernel/scheduler/eas_behaviour.py +++ b/lisa/tests/kernel/scheduler/eas_behaviour.py @@ -33,6 +33,7 @@ from lisa.env import TestEnv from lisa.utils import ArtifactPath from lisa.energy_model import EnergyModel from lisa.trace import requires_events +from lisa.perf_analysis import PerfAnalysis class EASBehaviour(RTATestBundle, abc.ABC): """ @@ -372,6 +373,41 @@ class EASBehaviour(RTATestBundle, abc.ABC): res.add_metric("energy threshold", threshold, 'bogo-joules') return res + def test_slack(self, negative_slack_allowed_pct=15) -> ResultBundle: + """ + Assert that the RTApp workload was given enough performance + + :param negative_slack_allowed_pct: Allowed percentage of RT-app task + activations with negative slack. + :type negative_slack_allowed_pct: int + + Use :class:`lisa.perf_analysis.PerfAnalysis` to find instances where the RT-App workload + wasn't able to complete its activations (i.e. its reported "slack" + was negative). Assert that this happened less than + ``negative_slack_allowed_pct`` percent of the time. 
+ """ + pa = PerfAnalysis(self.res_dir) + + slacks = {} + + # Data is only collected for rt-app tasks, so it's safe to iterate over + # all of them + passed = True + for task in pa.tasks(): + slack = pa.df(task)["Slack"] + + bad_activations_pct = len(slack[slack < 0]) * 100 / len(slack) + if bad_activations_pct > negative_slack_allowed_pct: + passed = False + + slacks[task] = bad_activations_pct + + res = ResultBundle.from_bool(passed) + + for task, slack in slacks.items(): + res.add_metric("{} slack".format(task), slack, '%') + + return res @classmethod def unscaled_utilization(cls, capacity, utilization_pct): diff --git a/lisa/tests/kernel/scheduler/load_tracking.py b/lisa/tests/kernel/scheduler/load_tracking.py index a5a95ed57a542a89131735a825c992ca78542933..f3c48ad3245c021e03b493e83af2638469ade03b 100644 --- a/lisa/tests/kernel/scheduler/load_tracking.py +++ b/lisa/tests/kernel/scheduler/load_tracking.py @@ -470,16 +470,6 @@ class FreqInvariance(TestBundle, LoadTrackingHelpers): return test_item.test_task_load_avg(allowed_error_pct=allowed_error_pct) return self._test_all_freq(item_test) - def test_slack(self, negative_slack_allowed_pct=15): - """ - Aggregated version of - :meth:`lisa.tests.kernel.test_bundle.RTATestBundle.test_slack` - """ - def item_test(test_item): - return test_item.test_slack( - negative_slack_allowed_pct=negative_slack_allowed_pct) - return self._test_all_freq(item_test) - def _test_all_freq(self, item_test): """ Apply the `test_item` function on all instances of diff --git a/lisa/tests/kernel/test_bundle.py b/lisa/tests/kernel/test_bundle.py index d42b73b78b6addd62ae4909e40ad99c55c98cbb0..a1de00d6276d57aff16362d61f6478abcffbcc78 100644 --- a/lisa/tests/kernel/test_bundle.py +++ b/lisa/tests/kernel/test_bundle.py @@ -27,7 +27,6 @@ from devlib.target import KernelVersion from lisa.trace import Trace from lisa.wlgen.rta import RTA -from lisa.perf_analysis import PerfAnalysis from lisa.utils import Serializable, memoized, ArtifactPath from lisa.env import TestEnv @@ -412,40 +411,4 @@ class RTATestBundle(TestBundle, abc.ABC): return cls(res_dir, te.plat_info, rtapp_profile) - def test_slack(self, negative_slack_allowed_pct=15) -> ResultBundle: - """ - Assert that the RTApp workload was given enough performance - - :param negative_slack_allowed_pct: Allowed percentage of RT-app task - activations with negative slack. - :type negative_slack_allowed_pct: int - - Use :class:`lisa.perf_analysis.PerfAnalysis` to find instances where the RT-App workload - wasn't able to complete its activations (i.e. its reported "slack" - was negative). Assert that this happened less than - ``negative_slack_allowed_pct`` percent of the time. 
- """ - pa = PerfAnalysis(self.res_dir) - - slacks = {} - - # Data is only collected for rt-app tasks, so it's safe to iterate over - # all of them - passed = True - for task in pa.tasks(): - slack = pa.df(task)["Slack"] - - bad_activations_pct = len(slack[slack < 0]) * 100 / len(slack) - if bad_activations_pct > negative_slack_allowed_pct: - passed = False - - slacks[task] = bad_activations_pct - - res = ResultBundle.from_bool(passed) - - for task, slack in slacks.items(): - res.add_metric("{} slack".format(task), slack, '%') - - return res - # vim :set tabstop=4 shiftwidth=4 textwidth=80 expandtab diff --git a/shell/README.txt b/shell/README.txt index 6571d75c5b278603ba17a3711be9d5ea64672c45..5765cce3ff2a557e6a354d9745b1fbc68019b17f 100644 --- a/shell/README.txt +++ b/shell/README.txt @@ -33,5 +33,10 @@ lisa-report - Pretty format results of last test .:: Test commands -------------------------------------- -lisa-test - Run tests and assert behaviours +lisa-test - Run tests and assert behaviours. + This is just a wrapper around exekall that selects all tests + modules and use positional arguments as --select patterns. Also + the default configuration file ($LISA_HOME/target_conf.yml) will + be used if available (but this can be extended with + user-supplied --conf). diff --git a/shell/lisa_shell b/shell/lisa_shell index 4efeef743b7a2d91726b8e6281d27483ae4cdb1c..ac95a4680ffb02daf3a894e67e85e6f7737651de 100755 --- a/shell/lisa_shell +++ b/shell/lisa_shell @@ -389,11 +389,29 @@ echo # LISA Tests utility functions ################################################################################ +export LISA_CONF="$LISA_HOME/target_conf.yml" export LISA_RESULT_ROOT=$LISA_HOME/results export EXEKALL_ARTIFACT_ROOT=${EXEKALL_ARTIFACT_ROOT:-$LISA_RESULT_ROOT} function lisa-test { - exekall "$@" + # Add --conf target_conf.yml if the file exists + if [[ -e "$LISA_CONF" ]]; then + local conf_opt=('--conf' "$LISA_CONF") + else + local conf_opt=() + fi + + local cmd=( + exekall run "$LISA_HOME/lisa/tests/kernel/" \ + "${conf_opt[@]}" \ + --select-multiple "$@" + ) + + # Show the command before running, so --help makes more sense + echo "${cmd[@]}" + echo + + "${cmd[@]}" } ################################################################################ diff --git a/target_conf.yml b/target_conf.yml index 647a2425d6e0148faaf5e92fe5e6e035ebfa121d..32f8b289435d5cad83b4e223378c8db4d78e7d52 100644 --- a/target_conf.yml +++ b/target_conf.yml @@ -12,7 +12,7 @@ target-conf: # - linux : accessed via SSH connection # - android : accessed via ADB connection # - host : run on the local host - kind : android + # kind : android # Board # Optional board name used for better prettier logs @@ -25,7 +25,7 @@ target-conf: # device: 00b1346f0878ccb1 # Login username (has to be sudo enabled) - username: root + # username: root # Login credentials # You can specify either a password or keyfile @@ -70,7 +70,7 @@ target-conf: platform-info: # Include a preset platform-info file, instead of defining the keys directly here. # Note that you cannot use !include and define keys at the same time. 
- !include $LISA_HOME/lisa/platforms/juno_r0.yml + # !include $LISA_HOME/lisa/platforms/juno_r0.yml # conf: # rtapp: # # Calibration mapping of CPU numbers to calibration value for rtapp diff --git a/tools/bisector/bisector/bisector.py b/tools/bisector/bisector/bisector.py index 713cb4333c60c4c86f774cf7cbc94d433f418d4c..4435768bbce75062e316ee9ca893f26525e3dd0f 100755 --- a/tools/bisector/bisector/bisector.py +++ b/tools/bisector/bisector/bisector.py @@ -105,6 +105,16 @@ def mask_signals(unblock=False): signal.SIGHUP, }) +def filter_keys(mapping, remove=None, keep=None): + return { + k: v + for k, v in mapping.items() + if ( + (remove is None or k not in remove) + and (keep is None or k in keep) + ) + } + sig_exception_lock = threading.Lock() def raise_sig_exception(sig, frame): """Turn some signals into exceptions that can be caught by user code.""" @@ -1768,7 +1778,8 @@ class ExekallLISATestStep(ShellStep): __init__ = dict( compress_artifact = BoolParam('compress the exekall artifact directory in an archive'), upload_artifact = BoolParam('upload the exekall artifact directory to Artifactorial as the execution goes, and delete the local archive.'), - **StepBase.options['__init__'], + # Some options are not supported + **filter_keys(StepBase.options['__init__'], remove={'trials'}), ), report_results = dict( verbose = StepBase.options['report_results']['verbose'], @@ -1797,6 +1808,7 @@ class ExekallLISATestStep(ShellStep): upload_artifact = Default, **kwargs ): + kwargs['trials'] = 1 super().__init__(**kwargs) self.upload_artifact = upload_artifact @@ -1808,11 +1820,11 @@ class ExekallLISATestStep(ShellStep): self.compress_artifact = compress_artifact def run(self, i_stack, service_hub): - # Add a level of UUID under the root, so we can handle multiple trials - artifact_path = os.path.join( - os.getenv('EXEKALL_ARTIFACT_ROOT', './exekall_artifact'), - uuid.uuid4().hex, - ) + artifact_path = os.getenv( + 'EXEKALL_ARTIFACT_ROOT', + # default value + './exekall_artifact' + ), # This also strips the trailing /, which is needed later on when # archiving the artifact. @@ -1824,9 +1836,6 @@ class ExekallLISATestStep(ShellStep): 'EXEKALL_ARTIFACT_ROOT': str(artifact_path), } - if self.trials > 1: - warn("More than one trials requested for exekall LISA test, only the last trial's xUnit XML file will be used.") - res_list = self._run_cmd(i_stack, env=env) ret = res_list[-1][0] diff --git a/tools/exekall/exekall/_utils.py b/tools/exekall/exekall/_utils.py index 6e33b24038476257b6c11c1046a8c5dbf4de2449..66c7ee259c03a6ce0cc4d9181d01d7c1bfd1859e 100644 --- a/tools/exekall/exekall/_utils.py +++ b/tools/exekall/exekall/_utils.py @@ -16,26 +16,30 @@ # limitations under the License. 
# -import types -import uuid -import inspect -import functools -import fnmatch import collections import contextlib +import fnmatch +import functools +import gc import importlib +import inspect import io import itertools import logging import pathlib import pickle +import subprocess import sys +import tempfile import traceback +import types +import uuid +import glob class NotSerializableError(Exception): pass -def get_class_from_name(cls_name, module_map): +def get_class_from_name(cls_name, module_map=sys.modules): possible_mod_set = { mod_name for mod_name in module_map.keys() @@ -78,7 +82,7 @@ def get_mro(cls): assert isinstance(cls, type) return inspect.getmro(cls) -def get_name(obj, full_qual=True, qual=True): +def get_name(obj, full_qual=True, qual=True, desugar_cls_meth=False): # full_qual enabled implies qual enabled _qual = qual or full_qual # qual disabled implies full_qual disabled @@ -292,10 +296,9 @@ def infer_mod_name(python_src): is_package = True, ) - module_name = '.'.join(( - ('.'.join(module_parents[0].parts)), - module_basename - )) + module_dotted_path = list(module_parents[0].parts) + [module_basename] + module_name = '.'.join(module_dotted_path) + else: module_name = get_module_basename(python_src) @@ -309,10 +312,11 @@ def find_customization_module_set(module_set): i += 1 yield '.'.join(l[:i]) - try: + # Exception raised changed in 3.7: + # https://docs.python.org/3/library/importlib.html#importlib.util.find_spec + if sys.version_info >= (3, 7): import_excep = ModuleNotFoundError - # Python < 3.6 - except NameError: + else: import_excep = AttributeError package_names_list = [ @@ -340,8 +344,23 @@ def find_customization_module_set(module_set): return customization_module_set +def import_paths(paths): + def import_it(path): + # Recursively import all modules when passed folders + if path.is_dir(): + for python_src in glob.iglob(str(path/'**'/'*.py'), recursive=True): + yield import_file(python_src) + # If passed a file, just import it directly + else: + yield import_file(path) + + return set(itertools.chain.from_iterable( + import_it(pathlib.Path(path)) + for path in paths + )) + def import_file(python_src, module_name=None, is_package=False): - python_src = pathlib.Path(python_src) + python_src = pathlib.Path(python_src).resolve() # Directly importing __init__.py does not really make much sense and may # even break, so just import its package instead. 
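As a usage sketch of the helper above (paths are hypothetical): ``import_paths`` accepts a mix of files and directories, recursively imports every ``*.py`` file found under the directories, and returns the set of imported module objects::

    import exekall._utils as utils

    # Import all Python sources under a test package folder, so the
    # classes and methods they define become discoverable
    module_set = utils.import_paths(['lisa/tests/kernel/'])
    for mod in sorted(module_set, key=lambda mod: mod.__name__):
        print(mod.__name__)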
@@ -411,19 +430,82 @@
     importlib.invalidate_caches()
     return module
 
-def flatten_nested_seq(seq):
-    return list(itertools.chain.from_iterable(seq))
+def flatten_seq(seq, levels=1):
+    if levels == 0:
+        return seq
+    else:
+        seq = list(itertools.chain.from_iterable(seq))
+        return flatten_seq(seq, levels=levels - 1)
+
+def take_first(iterable):
+    for i in iterable:
+        return i
+    return NoValue
+
+class _NoValueType:
+    # Use a singleton pattern to make sure that even deserialized instances
+    # will be the same object
+    def __new__(cls):
+        try:
+            return cls._instance
+        except AttributeError:
+            obj = super().__new__(cls)
+            cls._instance = obj
+            return obj
+
+    def __eq__(self, other):
+        return isinstance(other, _NoValueType)
+
+    def __hash__(self):
+        return 0
+
+    def __bool__(self):
+        return False
+
+    def __repr__(self):
+        return 'NoValue'
-def load_serial_from_db(db, uuid_seq=None, type_pattern_seq=None):
-    def uuid_predicate(serial):
-        return (
-            serial.value_uuid in uuid_seq
-            or serial.excep_uuid in uuid_seq
-        )
+NoValue = _NoValueType()
+
+
+class RestartableIter:
+    """
+    Wrap an iterator to give a new iterator that is restartable.
+    """
+    def __init__(self, it):
+        self.values = []
+
+        # Wrap the iterator to update the memoized values
+        def wrapped(it):
+            for x in it:
+                self.values.append(x)
+                yield x
+
+        self.it = wrapped(it)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return next(self.it)
+        except StopIteration:
+            # Use the stored values the next time we try to get an
+            # iterator again
+            self.it = iter(self.values)
+            raise
-    def type_pattern_predicate(serial):
-        return match_base_cls(type(serial.value), type_pattern_seq)
+
+def get_froz_val_set_set(db, uuid_seq=None, type_pattern_seq=None):
+
+    def uuid_predicate(froz_val):
+        return froz_val.uuid in uuid_seq
+
+    def type_pattern_predicate(froz_val):
+        return match_base_cls(type(froz_val.value), type_pattern_seq)
 
     if type_pattern_seq and not uuid_seq:
         predicate = type_pattern_predicate
@@ -432,13 +514,13 @@
 
     elif not uuid_seq and not type_pattern_seq:
-        predicate = lambda serial: True
+        predicate = lambda froz_val: True
 
     else:
-        def predicate(serial):
-            return uuid_predicate(serial) and type_pattern_predicate(serial)
+        def predicate(froz_val):
+            return uuid_predicate(froz_val) and type_pattern_predicate(froz_val)
 
-    return db.obj_store.get_by_predicate(predicate)
+    return db.get_by_predicate(predicate, flatten=False, deduplicate=True)
 
 def match_base_cls(cls, pattern_list):
     # Match on the name of the class of the object and all its base classes
     for base_cls in get_mro(cls):
         base_cls_name = get_name(base_cls, full_qual=True)
         if not base_cls_name:
             continue
-        if any(
-            fnmatch.fnmatch(base_cls_name, pattern)
-            for pattern in pattern_list
-        ):
+        if match_name(base_cls_name, pattern_list):
             return True
     return False
 
@@ -457,10 +536,59 @@
 def match_name(name, pattern_list):
     if name is None:
         return False
-    return any(
-        fnmatch.fnmatch(name, pattern)
+
+    if not pattern_list:
+        return False
+
+    neg_patterns = {
+        pattern[1:]
         for pattern in pattern_list
-    )
+        if pattern.startswith('!')
+    }
+
+    pos_patterns = {
+        pattern
+        for pattern in pattern_list
+        if not pattern.startswith('!')
+    }
+
+    invert = lambda x: not x
+    identity = lambda x: x
+
+    def 
check(pattern_set, f): + if pattern_set: + ok = any( + fnmatch.fnmatch(name, pattern) + for pattern in pattern_set + ) + return f(ok) + else: + return True + + return (check(pos_patterns, identity) and check(neg_patterns, invert)) + +def get_common_base(cls_list): + # MRO in which "object" will appear first + def rev_mro(cls): + return reversed(inspect.getmro(cls)) + + def common(cls1, cls2): + # Get the most derived class that is in common in the MRO of cls1 and + # cls2 + for b1, b2 in itertools.takewhile( + lambda b1_b2: b1_b2[0] is b1_b2[1], + zip(rev_mro(cls1), rev_mro(cls2)) + ): + pass + return b1 + + return functools.reduce(common, cls_list) + +def get_subclasses(cls): + subcls_set = {cls} + for subcls in cls.__subclasses__(): + subcls_set.update(get_subclasses(subcls)) + return subcls_set def get_recursive_module_set(module_set, package_set): """Retrieve the set of all modules recurisvely imported from the modules in @@ -493,3 +621,40 @@ def _get_recursive_module_set(module, module_set, package_set): _get_recursive_module_set(imported_module, module_set, package_set) +@contextlib.contextmanager +def disable_gc(): + """ + Context manager to disable garbage collection. + + This can result in significant speed-up in code creating a lot of objects, + like ``pickle.load()``. + """ + if not gc.isenabled(): + yield + return + + gc.disable() + try: + yield + finally: + gc.enable() + +def render_graphviz(expr): + graphviz = expr.get_structure(graphviz=True) + with tempfile.NamedTemporaryFile('wt') as f: + f.write(graphviz) + f.flush() + try: + svg = subprocess.check_output( + ['dot', f.name, '-Tsvg'], + stderr=subprocess.DEVNULL, + ).decode('utf-8') + # If "dot" is not installed + except FileNotFoundError: + pass + except subprocess.CalledProcessError as e: + debug('dot failed to execute: {}'.format(e)) + else: + return (True, svg) + + return (False, graphviz) diff --git a/tools/exekall/exekall/customization.py b/tools/exekall/exekall/customization.py index a180606be1e3cc1c1500a7bbce22530d16049783..60d55cae93568887cd87d36b3511575683f4d8e3 100644 --- a/tools/exekall/exekall/customization.py +++ b/tools/exekall/exekall/customization.py @@ -18,8 +18,8 @@ import numbers -from exekall.engine import NoValue, StorageDB -from exekall.utils import out, get_name +from exekall.engine import ValueDB +from exekall.utils import out, get_name, NoValue, get_subclasses class AdaptorBase: name = 'default' @@ -40,26 +40,24 @@ class AdaptorBase: tags = {} return tags - load_db = None - def update_expr_data(self, expr_data): return - def filter_op_pool(self, op_pool): + def filter_op_set(self, op_set): return { - op for op in op_pool + op for op in op_set # Only select operators with non-empty parameter list. This # rules out all classes __init__ that do not take parameter, as # they are typically not interesting to us. 
if op.get_prototype()[0] } - def get_prebuilt_list(self): - return [] + def get_prebuilt_set(self): + return set() - def get_hidden_callable_set(self, op_map): - self.hidden_callable_set = set() - return self.hidden_callable_set + def get_hidden_op_set(self, op_set): + self.hidden_op_set = set() + return self.hidden_op_set @staticmethod def register_cli_param(parser): @@ -70,10 +68,11 @@ class AdaptorBase: return {'*Result'} def resolve_cls_name(self, goal): - return utils.get_class_from_name(goal, sys.modules) + return utils.get_class_from_name(goal) - def load_db(self, db_path): - return StorageDB.from_path(db_path) + @staticmethod + def load_db(*args, **kwargs): + return ValueDB.from_path(*args, **kwargs) def finalize_expr(self, expr): pass @@ -81,19 +80,21 @@ class AdaptorBase: def result_str(self, result): val = result.value if val is NoValue or val is None: - failed_parents = result.get_failed_expr_vals() - for failed_parent in failed_parents: + for failed_parent in result.get_excep(): excep = failed_parent.excep return 'EXCEPTION ({type}): {msg}'.format( - type = get_name(type(excep)), + type = get_name(type(excep), full_qual=False), msg = excep ) - return 'No result computed' + return 'No value computed' else: return str(val) - def process_results(self, result_map): - hidden_callable_set = self.hidden_callable_set + def get_summary(self, result_map): + hidden_callable_set = { + op.callable_ + for op in self.hidden_op_set + } # Get all IDs and compute the maximum length to align the output result_id_map = { @@ -108,27 +109,32 @@ class AdaptorBase: max_id_len = len(max(result_id_map.values(), key=len)) + summary = [] for expr, result_list in result_map.items(): for result in result_list: msg = self.result_str(result) msg = msg + '\n' if '\n' in msg else msg - out('{id:<{max_id_len}} {result}'.format( + summary.append('{id:<{max_id_len}} {result}'.format( id=result_id_map[result], result=msg, max_id_len=max_id_len, )) + return '\n'.join(summary) @classmethod def get_adaptor_cls(cls, name=None): - subcls_list = list(cls.__subclasses__()) - if len(subcls_list) > 1 and not name: - raise ValueError('An adaptor name must be specified if there is more than one adaptor to choose from') + subcls_list = list(get_subclasses(cls) - {cls}) + if not name: + if len(subcls_list) > 1: + raise ValueError('An adaptor name must be specified if there is more than one adaptor to choose from') + else: + if len(subcls_list) > 0: + return subcls_list[0] + else: + return cls for subcls in subcls_list: - if name: - if subcls.name == name: - return subcls - else: + if subcls.name == name: return subcls return None diff --git a/tools/exekall/exekall/engine.py b/tools/exekall/exekall/engine.py index a1fadf03ab72e807ea79e9cd3e23c273baa2b321..512b01afceffe33fb3594a528e5ceae9f6a702f9 100644 --- a/tools/exekall/exekall/engine.py +++ b/tools/exekall/exekall/engine.py @@ -22,45 +22,19 @@ from collections import OrderedDict import copy import itertools import functools -import gzip +import lzma import pathlib import contextlib +import pickle import pprint - -import ruamel.yaml +import pickletools import exekall._utils as utils - -def take_first(iterable): - for i in iterable: - return i - return NoValue +from exekall._utils import NoValue class NoOperatorError(Exception): pass -class _NoValueType: - # Use a singleton pattern to make sure that even deserialized instances - # will be the same object - def __new__(cls): - try: - return cls._instance - except AttributeError: - obj = super().__new__(cls) - cls._instance = 
obj - return obj - - def __bool__(self): - return False - - def __repr__(self): - return 'NoValue' - - def __eq__(self, other): - return type(self) is type(other) - -NoValue = _NoValueType() - class IndentationManager: def __init__(self, style): self.style = style @@ -75,20 +49,73 @@ class IndentationManager: def __str__(self): return str(self.style) * self.level -class StorageDB: - _yaml = ruamel.yaml.YAML(typ='unsafe') +class ValueDB: + # Version 4 is available since Python 3.4 and improves a bit loading and + # dumping speed. + PICKLE_PROTOCOL = 4 + + def __init__(self, froz_val_seq_list): + # Avoid storing duplicate FrozenExprVal sharing the same value/excep + # UUID + self.froz_val_seq_list = self._dedup_froz_val_seq_list(froz_val_seq_list) @classmethod - def _init_yaml(cls): - """Needs to be called only once""" - yaml = cls._yaml + def _dedup_froz_val_seq_list(cls, froz_val_seq_list): + """ + Avoid keeping :class:`FrozenExprVal` that share the same value or + excep UUID, since they are duplicates of each-other. + """ - yaml.allow_unicode = True - yaml.default_flow_style = False - yaml.indent = 4 + # First pass: find all frozen values corresponding to a given UUID + uuid_map = {} + def update_uuid_map(froz_val): + uuid_map.setdefault(froz_val.uuid, set()).add(froz_val) + return froz_val + cls._froz_val_dfs(froz_val_seq_list, update_uuid_map) + + # Make sure no deduplication will occur on None, as it is used as a + # marker when no exception was raised or when no value was available. + uuid_map[(None, None)] = set() + + # Select one FrozenExprVal for each UUID pair + def select_froz_val(froz_val_set): + candidates = [ + froz_val + for froz_val in froz_val_set + # We discard candidates that have no parameters, as they + # contain less information than the ones that do. This is + # typically the case for PrebuiltOperator values + if froz_val.param_map + ] - def __init__(self, obj_store): - self.obj_store = obj_store + # At this point, there should be no more than one "original" value, + # the other candidates were just values of PrebuiltOperator, or are + # completely equivalent to the original value + + if candidates: + return candidates[0] + # If there was no better candidate, just return the first one + else: + return utils.take_first(froz_val_set) + + uuid_map = { + uuid_pair: select_froz_val(froz_val_set) + for uuid_pair, froz_val_set in uuid_map.items() + } + + # Second pass: only keep one frozen value for each UUID + def rewrite_graph(froz_val): + return uuid_map[froz_val.uuid] + + return cls._froz_val_dfs(froz_val_seq_list, rewrite_graph) + + @classmethod + def merge(cls, db_seq): + froz_val_seq_list = list(itertools.chain(*( + db.froz_val_seq_list + for db in db_seq + ))) + return cls(froz_val_seq_list) @classmethod def from_path(cls, path, relative_to=None): @@ -98,327 +125,195 @@ class StorageDB: relative_to = pathlib.Path(relative_to).parent path = pathlib.Path(relative_to, path) - with gzip.open(str(path), 'rt', encoding='utf-8') as f: - db = cls._yaml.load(f) + with lzma.open(str(path), 'rb') as f: + # Disabling garbage collection while loading result in significant + # speed improvement, since it creates a lot of new objects in a + # very short amount of time. 
+ with utils.disable_gc(): + db = pickle.load(f) assert isinstance(db, cls) return db - def to_path(self, path): - with gzip.open(str(path), 'wt', encoding='utf-8') as f: - self._yaml.dump(self, f) - - # Having it there shortens the output of the generated scripts and makes - # them more readable while avoiding to expose to much of the StorageDB - # internals - def by_uuid(self, *args, **kwargs): - return self.obj_store.by_uuid(*args, **kwargs) - -StorageDB._init_yaml() - -class ObjectStore: - def __init__(self, serial_seq_list, db_var_name='db'): - self.db_var_name = db_var_name - self.serial_seq_list = serial_seq_list - - def get_value_snippet(self, value): - _, id_uuid_map = self.get_indexes() - return '{db}.by_uuid({key})'.format( - db = self.db_var_name, - key = repr(id_uuid_map[id(value)]) - ) - - def by_uuid(self, uuid): - uuid_value_map, _ = self.get_indexes() - return uuid_value_map[uuid] - - # Since the content of the cache is not serialized, the maps will be - # regenerated when the object is restored. - @utils.once - def get_indexes(self): - uuid_value_map = dict() - id_uuid_map = dict() - - def update_map(serial_val): - for uuid_, val in ( - (serial_val.value_uuid, serial_val.value), - (serial_val.excep_uuid, serial_val.excep), - ): - uuid_value_map[uuid_] = val - id_uuid_map[id(val)] = uuid_ - - self._serial_val_dfs(update_map) - - return (uuid_value_map, id_uuid_map) - - def _serial_val_dfs(self, callback): - for serial_seq in self.serial_seq_list: - for serial_val in serial_seq: - self._do_serial_val_dfs(serial_val, callback) - - def _do_serial_val_dfs(cls, serial_val, callback): - callback(serial_val) - for serial_val in serial_val.param_expr_val_map.values(): - cls._do_serial_val_dfs(serial_val, callback) - - def get_all(self): - serial_seq_set = self.get_by_predicate(lambda serial: True) - all_set = set() - for serial_seq in serial_seq_set: - all_set.update(serial_seq) - return all_set - - def get_by_predicate(self, predicate): - """ - Return a set of sets, containing objects matching the predicate. - There is a set for each computed expression in the store, but the same - object will not be included twice (in case it is refered by different - expressions). + def to_path(self, path, optimize=True): """ - serial_seq_set = set() + Write the DB to the given file. - # When we reload instances of a class from the DB, we don't - # want anything else to be able to produce it, since we want to - # run on that existing data set + :param path: path to file to write the DB into + :type path: pathlib.Path or str - for serial_seq in self.serial_seq_list: - serial_set = set() - for serial in serial_seq: - serial_set.update(serial.get_parent_set(predicate)) + :param optimize: Optimize the representation of the DB. This may + increase the dump time, but should speed-up loading/file size. 
+ :type optimize: bool + """ + if optimize: + bytes_ = pickle.dumps(self, protocol=self.PICKLE_PROTOCOL) + bytes_ = pickletools.optimize(bytes_) + dumper = lambda f: f.write(bytes_) + else: + dumper = lambda f: pickle.dump(self, f, protocol=self.PICKLE_PROTOCOL) - serial_seq_set.add(frozenset(serial_set)) + with lzma.open(str(path), 'wb') as f: + dumper(f) - return serial_seq_set + @property + @utils.once + def _uuid_map(self): + uuid_map = dict() -class CycleError(Exception): - pass + def update_map(froz_val): + uuid_map[froz_val.uuid] = froz_val + return froz_val -class ExpressionWrapper: - def __init__(self, expr): - self.expr = expr + self._froz_val_dfs(self.froz_val_seq_list, update_map) - def __getattr__(self, attr): - return getattr(self.expr, attr) + return uuid_map @classmethod - def build_expr_list(cls, result_op_seq, op_map, cls_map, - non_produced_handler='raise', cycle_handler='raise'): - op_map = copy.copy(op_map) - cls_map = { - cls: compat_cls_set - for cls, compat_cls_set in cls_map.items() - # If there is at least one compatible subclass that is produced, we - # keep it, otherwise it will mislead _build_expr into thinking the - # class can be built where in fact it cannot - if compat_cls_set & op_map.keys() - } - internal_cls_set = {Consumer, ExprData} - for internal_cls in internal_cls_set: - op_map[internal_cls] = { - Operator(internal_cls, non_reusable_type_set=internal_cls_set) - } - cls_map[internal_cls] = [internal_cls] - - expr_list = list() - for result_op in result_op_seq: - expr_gen = cls._build_expr(result_op, op_map, cls_map, - op_stack = [], - non_produced_handler=non_produced_handler, - cycle_handler=cycle_handler, + def _froz_val_dfs(cls, froz_val_seq_list, callback): + return [ + FrozenExprValSeq( + froz_val_list=[ + cls._do_froz_val_dfs(froz_val, callback) + for froz_val in froz_val_seq + ], + param_map={ + param: cls._do_froz_val_dfs(froz_val, callback) + for param, froz_val in froz_val_seq.param_map.items() + } ) - for expr in expr_gen: - if expr.validate_expr(op_map): - expr_list.append(expr) - - return expr_list + for froz_val_seq in froz_val_seq_list + ] @classmethod - def _build_expr(cls, op, op_map, cls_map, op_stack, non_produced_handler, cycle_handler): - new_op_stack = [op] + op_stack - # We detected a cyclic dependency - if op in op_stack: - if cycle_handler == 'ignore': - return - elif callable(cycle_handler): - cycle_handler(tuple(op.callable_ for op in new_op_stack)) - return - elif cycle_handler == 'raise': - raise CycleError('Cyclic dependency found: {path}'.format( - path = ' -> '.join( - op.name for op in new_op_stack - ) - )) - else: - raise ValueError('Invalid cycle_handler') + def _do_froz_val_dfs(cls, froz_val, callback): + updated_froz_val = callback(froz_val) + updated_froz_val.param_map = { + param: cls._do_froz_val_dfs(param_froz_val, callback) + for param, param_froz_val in updated_froz_val.param_map.items() + } + return updated_froz_val - op_stack = new_op_stack + def get_by_uuid(self, uuid): + return self._uuid_map[uuid] - param_map, produced = op.get_prototype() - if param_map: - param_list, cls_list = zip(*param_map.items()) - # When no parameter is needed - else: - yield ExpressionWrapper(Expression(op, OrderedDict())) - return - - # Build all the possible combinations of types suitable as parameters - cls_combis = [cls_map.get(cls, list()) for cls in cls_list] + def get_by_predicate(self, predicate, flatten=True, deduplicate=False): + """ + Get objects matching the predicate. 
- # Only keep the classes for "self" on which the method can be applied - if op.is_method: - cls_combis[0] = [ - cls for cls in cls_combis[0] - # If the method with the same name would resolve to "op", then - # we keep this class as a candidate for "self", otherwise we - # discard it - if getattr(cls, op.callable_.__name__, None) is op.callable_ - ] + :param flatten: If False, return a set of frozenset of objects. + There is a frozenset set for each expression result that shared + their parameters. If False, the top-level set is flattened into a + set of objects matching the predicate. + :type flatten: bool - # Check that some produced classes are available for every parameter - ignored_indices = set() - for param, wanted_cls, available_cls in zip(param_list, cls_list, cls_combis): - if not available_cls: - # If that was an optional parameter, just ignore it without - # throwing an exception since it has a default value - if param in op.optional_param: - ignored_indices.add(param_list.index(param)) - else: - if non_produced_handler == 'ignore': - return - elif callable(non_produced_handler): - non_produced_handler(wanted_cls.__qualname__, op.name, param, - tuple(op.resolved_callable for op in op_stack) - ) - return - elif non_produced_handler == 'raise': - raise NoOperatorError('No operator can produce instances of {cls} needed for {op} (parameter "{param}" along path {path})'.format( - cls = wanted_cls.__qualname__, - op = op.name, - param = param, - path = ' -> '.join( - op.name for op in op_stack - ) - )) - else: - raise ValueError('Invalid non_produced_handler') + :param deduplicate: If True, there won't be duplicates across nested + sets. + :type deduplicate: bool + """ + froz_val_set_set = set() - param_list = utils.remove_indices(param_list, ignored_indices) - cls_combis = utils.remove_indices(cls_combis, ignored_indices) + # When we reload instances of a class from the DB, we don't + # want anything else to be able to produce it, since we want to + # run on that existing data set - param_list_len = len(param_list) + # Make sure we don't select the same froz_val twice + if deduplicate: + visited = set() + def wrapped_predicate(froz_val): + if froz_val in visited: + return False + else: + visited.add(froz_val) + return predicate(froz_val) + else: + wrapped_predicate = predicate + + for froz_val_seq in self.froz_val_seq_list: + froz_val_set = set() + for froz_val in itertools.chain( + # traverse all values, including the ones from the + # parameters, even when there was no value computed + # (because of a failed parent for example) + froz_val_seq, froz_val_seq.param_map.values() + ): + froz_val_set.update(froz_val.get_by_predicate(wrapped_predicate)) + + froz_val_set_set.add(frozenset(froz_val_set)) + + if flatten: + return set(utils.flatten_seq(froz_val_set_set)) + else: + return froz_val_set_set - # For all possible combinations of types - for cls_combi in itertools.product(*cls_combis): - cls_combi = list(cls_combi) + def get_all(self, **kwargs): + return self.get_by_predicate(lambda froz_val: True, **kwargs) - # Some classes may not be produced, but another combination - # with containing a subclass of it may actually be produced so we can - # just ignore that one. 
- op_combis = [ - op_map[cls] for cls in cls_combi - if cls in op_map - ] + def get_by_type(self, cls, include_subclasses=True, **kwargs): + if include_subclasses: + predicate = lambda froz_val: isinstance(froz_val.value, cls) + else: + predicate = lambda froz_val: type(froz_val.value) is cls + return self.get_by_predicate(predicate, **kwargs) + + def get_by_id(self, id_, qual=False, full_qual=False, **kwargs): + def predicate(froz_val): + return utils.match_name( + froz_val.get_id(qual=qual, full_qual=full_qual), + [id_] + ) - # Build all the possible combinations of operators returning these - # types - for op_combi in itertools.product(*op_combis): - op_combi = list(op_combi) + return self.get_by_predicate(predicate, **kwargs) - # Get all the possible ways of calling these operators - param_combis = itertools.product(*(cls._build_expr( - param_op, op_map, cls_map, - op_stack, non_produced_handler, cycle_handler, - ) for param_op in op_combi - )) +class ScriptValueDB: + def __init__(self, db, var_name='db'): + self.db = db + self.var_name = var_name - for param_combi in param_combis: - param_map = OrderedDict(zip(param_list, param_combi)) + def get_snippet(self, expr_val, attr): + return '{db}.get_by_uuid({uuid}).{attr}'.format( + db=self.var_name, + uuid=repr(expr_val.uuid), + attr=attr, + ) - # If all parameters can be built, carry on - if len(param_map) == param_list_len: - yield ExpressionWrapper( - Expression(op, param_map) - ) +class CycleError(Exception): + pass -class Expression: - def __init__(self, op, param_map, data=None): +class ExpressionBase: + def __init__(self, op, param_map): self.op = op # Map of parameters to other Expression self.param_map = param_map - self.data = data if data is not None else dict() - self.data_uuid = utils.create_uuid() - self.uuid = utils.create_uuid() - - self.discard_result() - - def validate_expr(self, op_map): - type_map, valid = self._get_type_map() - if not valid: - return False - - # Check that the Expression does not involve 2 classes that are compatible - cls_bags = [set(cls_list) for cls_list in op_map.values()] - cls_used = set(type_map.keys()) - for cls1, cls2 in itertools.product(cls_used, repeat=2): - for cls_bag in cls_bags: - if cls1 in cls_bag and cls2 in cls_bag: - return False - - return True - - def _get_type_map(self): - type_map = dict() - return (type_map, self._populate_type_map(type_map)) - def _populate_type_map(self, type_map): - value_type = self.op.value_type - # If there was already an Expression producing that type, the Expression - # is not valid - found_callable = type_map.get(value_type) - if found_callable is not None and found_callable is not self.op.callable_: - return False - type_map[value_type] = self.op.callable_ + @classmethod + def cse(cls, expr_list): + """ + Apply a flavor of common subexpressions elimination to the + Expression. 
+ """ - for param_expr in self.param_map.values(): - if not param_expr._populate_type_map(type_map): - return False - return True + expr_map = {} + return [ + expr._cse(expr_map) + for expr in expr_list + ] + def _cse(self, expr_map): + # Deep first + self.param_map = { + param: param_expr._cse(expr_map=expr_map) + for param, param_expr in self.param_map.items() + } - def get_param_map(self, reusable): - reusable = bool(reusable) - return OrderedDict( - (param, param_expr) - for param, param_expr - in self.param_map.items() - if bool(param_expr.op.reusable) == reusable + key = ( + self.op.callable_, + # get a nested tuple sorted by param name with the shape: + # ((param, val), ...) + tuple(sorted(self.param_map.items(), key=lambda k_v: k_v[0])) ) - def get_all_vals(self): - for result in self.result_list: - yield from result.value_list - - def find_result_list(self, param_expr_val_map): - def value_map(expr_val_map): - return OrderedDict( - # Extract the actual value from ExprValue - (param, expr_val.value) - for param, expr_val in expr_val_map.items() - ) - param_expr_val_map = value_map(param_expr_val_map) - - # Find the results that are matching the param_expr_val_map - return [ - result - for result in self.result_list - # Check if param_expr_val_map is a subset of the param_expr_val_map - # of the ExprValue. That allows checking for reusable parameters - # only. - if param_expr_val_map.items() <= value_map(result.param_expr_val_map).items() - ] - - def discard_result(self): - self.result_list = list() + return expr_map.setdefault(key, self) def __repr__(self): return ''.format( @@ -426,7 +321,13 @@ class Expression: id = hex(id(self)) ) - def pretty_structure(self, full_qual=True, indent=1): + def get_structure(self, full_qual=True, graphviz=False): + if graphviz: + return self._get_graphviz_structure(full_qual, level=0, visited=set()) + else: + return self._get_structure(full_qual=full_qual) + + def _get_structure(self, full_qual=True, indent=1): indent_str = 4 * ' ' * indent if isinstance(self.op, PrebuiltOperator): @@ -436,12 +337,11 @@ class Expression: out = '{op_name} ({value_type_name})'.format( op_name = op_name, - value_type_name = utils.get_name(self.op.value_type, full_qual=full_qual) -, + value_type_name = utils.get_name(self.op.value_type, full_qual=full_qual), ) if self.param_map: out += ':\n'+ indent_str + ('\n'+indent_str).join( - '{param}: {desc}'.format(param=param, desc=desc.pretty_structure( + '{param}: {desc}'.format(param=param, desc=desc._get_structure( full_qual=full_qual, indent=indent+1 )) @@ -449,11 +349,71 @@ class Expression: ) return out - def get_failed_expr_vals(self): - for expr_val in self.get_all_vals(): - yield from expr_val.get_failed_expr_vals() + def _get_graphviz_structure(self, full_qual, level, visited): + if self in visited: + return '' + else: + visited.add(self) + + if isinstance(self.op, PrebuiltOperator): + op_name = '' + else: + op_name = self.op.get_name(full_qual=True) + + # Use the Python id as it is guaranteed to be unique during the lifetime of + # the object, so it is a good candidate to refer to a node + uid = id(self) - def get_id(self, *args, marked_expr_val_set=None, mark_excep=False, hidden_callable_set=None, **kwargs): + src_file, src_line = self.op.src_loc + if src_file and src_line: + src_loc = '({}:{})'.format(src_file, src_line) + else: + src_loc = '' + + out = ['{uid} [label="{op_name} {reusable}\\ntype: {value_type_name}\\n{loc}"]'.format( + uid=uid, + op_name=op_name, + reusable='(reusable)' if self.op.reusable else 
'(non-reusable)', + value_type_name=utils.get_name(self.op.value_type, full_qual=full_qual), + loc=src_loc, + )] + if self.param_map: + for param, param_expr in self.param_map.items(): + out.append( + '{param_uid} -> {uid} [label="{param}"]'.format( + param_uid=id(param_expr), + uid=uid, + param=param, + ) + ) + + out.append( + param_expr._get_graphviz_structure( + full_qual=full_qual, + level=level+1, + visited=visited, + ) + ) + + if level == 0: + title = 'Structure of ' + self.get_id(qual=False) + node_out = 'digraph structure {{\n{}\nlabel="' + title + '"\n}}' + else: + node_out = '{}' + # dot seems to dislike empty line with just ";" + return node_out.format(';\n'.join(line for line in out if line.strip())) + + def get_id(self, *args, marked_expr_val_set=set(), **kwargs): + id_, marker = self._get_id(*args, + marked_expr_val_set=marked_expr_val_set, + **kwargs + ) + if marked_expr_val_set: + return '\n'.join((id_, marker)) + else: + return id_ + + def _get_id(self, with_tags=True, full_qual=True, qual=True, style=None, expr_val=None, marked_expr_val_set=None, hidden_callable_set=None): if hidden_callable_set is None: hidden_callable_set = set() @@ -461,193 +421,143 @@ class Expression: # to the ID. It is mostly an implementation detail. hidden_callable_set.update((Consumer, ExprData)) - # Mark all the values that failed to be computed because of an - # exception - if mark_excep: - marked_expr_val_set = set(self.get_failed_expr_vals()) - - for id_, marker in self._get_id( - marked_expr_val_set=marked_expr_val_set, hidden_callable_set=hidden_callable_set, - *args, **kwargs - ): - if marked_expr_val_set: - yield '\n'.join((id_, marker)) - else: - yield id_ - - def _get_id(self, with_tags=True, full_qual=True, qual=True, expr_val=None, marked_expr_val_set=None, hidden_callable_set=None): - # When asked about NoValue, it means the caller did not have any value - # computed for that parameter, but still wants an ID. Obviously, it - # cannot have any tag since there is no ExprValue available to begin - # with. 
- if expr_val is NoValue: - with_tags = False - - # No specific value was asked for, so we will cover the IDs of all - # values - if expr_val is None or expr_val is NoValue: - def grouped_expr_val_list(): - # Make sure we yield at least once even if no computed value - # is available, so _get_id() is called at least once - if (not self.result_list) or (not with_tags): - yield (OrderedDict(), []) - else: - for result in self.result_list: - yield (result.param_expr_val_map, result.value_list) - + if expr_val is None: + param_map = dict() # If we were asked about the ID of a specific value, make sure we # don't explore other paths that lead to different values else: - def grouped_expr_val_list(): - # Only yield the ExprValue we are interested in - yield (expr_val.param_expr_val_map, [expr_val]) - - for param_expr_val_map, value_list in grouped_expr_val_list(): - yield from self._get_id_internal( - param_expr_val_map=param_expr_val_map, - value_list=value_list, - with_tags=with_tags, - marked_expr_val_set=marked_expr_val_set, - hidden_callable_set=hidden_callable_set, - full_qual=full_qual, - qual=qual - ) + param_map = expr_val.param_map + + return self._get_id_internal( + param_map=param_map, + expr_val=expr_val, + with_tags=with_tags, + marked_expr_val_set=marked_expr_val_set, + hidden_callable_set=hidden_callable_set, + full_qual=full_qual, + qual=qual, + style=style, + ) - def _get_id_internal(self, param_expr_val_map, value_list, with_tags, marked_expr_val_set, hidden_callable_set, full_qual, qual): + def _get_id_internal(self, param_map, expr_val, with_tags, marked_expr_val_set, hidden_callable_set, full_qual, qual, style): separator = ':' marker_char = '^' + get_id_kwargs = dict( + full_qual=full_qual, + qual=qual, + style=style + ) if marked_expr_val_set is None: marked_expr_val_set = set() - # We only get the ID's of the parameter ExprValue that lead to the - # ExprValue we are interested in + # We only get the ID's of the parameter ExprVal that lead to the + # ExprVal we are interested in param_id_map = OrderedDict( - (param, take_first(param_expr._get_id( + (param, param_expr._get_id( + **get_id_kwargs, with_tags = with_tags, - full_qual = full_qual, - qual = qual, - # Pass a NoValue when there is no value available, since - # None means all possible IDs (we just want one here). 
- expr_val = param_expr_val_map.get(param, NoValue), + # Pass None when there is no value available, so we will get + # a non-tagged ID when there is no value computed + expr_val = param_map.get(param), marked_expr_val_set = marked_expr_val_set, hidden_callable_set = hidden_callable_set, - ))) + )) for param, param_expr in self.param_map.items() if ( param_expr.op.callable_ not in hidden_callable_set # If the value is marked, the ID will not be hidden - or param_expr_val_map.get(param) in marked_expr_val_set + or param_map.get(param) in marked_expr_val_set ) ) - def tags_iter(value_list): - if value_list: - for expr_val in value_list: - if with_tags: - tag = expr_val.format_tags() - else: - tag = '' - yield (expr_val, tag) - # Yield at least once without any tag even if there is no computed - # value available + def get_tags(expr_val): + if expr_val is not None: + if with_tags: + tag = expr_val.format_tags() + else: + tag = '' + return tag else: - yield None, '' + return '' def get_marker_char(expr_val): return marker_char if expr_val in marked_expr_val_set else ' ' + tag_str = get_tags(expr_val) + # No parameter to worry about if not param_id_map: - for expr_val, tag_str in tags_iter(value_list): - id_ = self.op.get_id(full_qual=full_qual, qual=qual) + tag_str - marker_str = get_marker_char(expr_val) * len(id_) - yield (id_, marker_str) + id_ = self.op.get_id(**get_id_kwargs) + tag_str + marker_str = get_marker_char(expr_val) * len(id_) + return (id_, marker_str) - # For all ExprValue we were asked about, we will yield an ID + # Recursively build an ID else: - for expr_val, tag_str in tags_iter(value_list): - # Make a copy to be able to pop items from it - param_id_map = copy.copy(param_id_map) - - # Extract the first parameter to always use the prefix - # notation, i.e. its value preceding the ID of the current - # Expression - if param_id_map: - param, (param_id, param_marker) = param_id_map.popitem(last=False) - else: - param_id = '' - param_marker = '' + # Make a copy to be able to pop items from it + param_id_map = copy.copy(param_id_map) - if param_id: - separator_spacing = ' ' * len(separator) - param_str = param_id + separator - else: - separator_spacing = '' - param_str = '' + # Extract the first parameter to always use the prefix + # notation, i.e. 
its value preceding the ID of the current + # Expression + param, (param_id, param_marker) = param_id_map.popitem(last=False) - op_str = '{op}{tags}'.format( - op = self.op.get_id(full_qual=full_qual, qual=qual), - tags = tag_str, - ) - id_ = '{param_str}{op_str}'.format( - param_str = param_str, - op_str = op_str, - ) - marker_str = '{param_marker}{separator}{op_marker}'.format( - param_marker = param_marker, - separator = separator_spacing, - op_marker = len(op_str) * get_marker_char(expr_val) - ) + if param_id: + separator_spacing = ' ' * len(separator) + param_str = param_id + separator + else: + separator_spacing = '' + param_str = '' - # If there are some remaining parameters, show them in - # parenthesis at the end of the ID - if param_id_map: - param_str = '(' + ','.join( - param + '=' + param_id - for param, (param_id, param_marker) - # Sort by parameter name to have a stable ID - in param_id_map.items() - if param_id - ) + ')' - id_ += param_str - param_marker = ' '.join( - ' ' * (len(param) + 1) + param_marker - for param, (param_id, param_marker) - # Sort by parameter name to have a stable ID - in param_id_map.items() - if param_id - ) + ' ' - - marker_str += ' ' + param_marker - - yield (id_, marker_str) + op_str = '{op}{tags}'.format( + op = self.op.get_id(**get_id_kwargs), + tags = tag_str, + ) + id_ = '{param_str}{op_str}'.format( + param_str = param_str, + op_str = op_str, + ) + marker_str = '{param_marker}{separator}{op_marker}'.format( + param_marker = param_marker, + separator = separator_spacing, + op_marker = len(op_str) * get_marker_char(expr_val) + ) - @classmethod - def get_all_serializable_vals(cls, expr_seq, *args, **kwargs): - serialized_map = dict() - result_list = list() - for expr in expr_seq: - for result in expr.result_list: - result_list.append([ - expr_val._get_serializable(serialized_map, *args, **kwargs) - for expr_val in result.value_list - ]) - - return result_list + # If there are some remaining parameters, show them in + # parenthesis at the end of the ID + if param_id_map: + param_str = '(' + ','.join( + param + '=' + param_id + for param, (param_id, param_marker) + # Sort by parameter name to have a stable ID + in param_id_map.items() + if param_id + ) + ')' + id_ += param_str + param_marker = ' '.join( + ' ' * (len(param) + 1) + param_marker + for param, (param_id, param_marker) + # Sort by parameter name to have a stable ID + in param_id_map.items() + if param_id + ) + ' ' + + marker_str += ' ' + param_marker + return (id_, marker_str) def get_script(self, *args, **kwargs): return self.get_all_script([self], *args, **kwargs) @classmethod - def get_all_script(cls, expr_list, prefix='value', db_path='storage.yml.gz', db_relative_to=None, db_loader=None, obj_store=None): + def get_all_script(cls, expr_list, prefix='value', db_path='VALUE_DB.pickle.xz', db_relative_to=None, db_loader=None, db=None): assert expr_list - if obj_store is None: - serial_list = Expression.get_all_serializable_vals(expr_list) - obj_store = ObjectStore(serial_list) + if db is None: + froz_val_seq_list = FrozenExprValSeq.from_expr_list(expr_list) + script_db = ScriptValueDB(ValueDB(froz_val_seq_list)) + else: + script_db = ScriptValueDB(db) - db_var_name = obj_store.db_var_name def make_comment(txt): joiner = '\n# ' @@ -664,11 +574,9 @@ class Expression: for i, expr in enumerate(expr_list): script += ( '#'*80 + '\n# Computed expressions:' + - ''.join( - make_comment(id_) - for id_ in expr.get_id(mark_excep=True, full_qual=False) - ) + '\n' + - 
make_comment(expr.pretty_structure()) + '\n\n' + make_comment(expr.get_id(mark_excep=True, full_qual=False)) + + '\n' + + make_comment(expr.get_structure()) + '\n\n' ) idt = IndentationManager(' '*4) @@ -676,29 +584,16 @@ class Expression: result_name, snippet = expr._get_script( reusable_outvar_map = reusable_outvar_map, prefix = prefix + str(i), - obj_store = obj_store, + script_db = script_db, module_name_set = module_name_set, idt = idt, expr_val_set = expr_val_set, consumer_expr_stack = [], ) - # If we can expect eval() to work on the representation, we - # use that - if pprint.isreadable(expr.data): - expr_data = pprint.pformat(expr.data) - else: - # Otherwise, we try to get it from the DB - try: - expr_data = obj_store.get_value_snippet(expr.data) - # If the expr_data was not used when computing subexpressions - # (that may happen if some subrexpressions were already - # computed for an other expression), we just bail out, hoping - # that nothing will need EXPR_DATA to be defined. That should - # not happen often as EXPR_DATA is supposed to stay - # pretty-printable - except KeyError: - expr_data = '{} # cannot be pretty-printed' + # ExprData must be printable to a string representation that can be + # fed back to eval() + expr_data = pprint.pformat(expr.data) expr_data_snippet = cls.EXPR_DATA_VAR_NAME + ' = ' + expr_data + '\n' @@ -715,7 +610,7 @@ class Expression: # Get the name of the customized db_loader if db_loader is None: db_loader_name = '{cls_name}.from_path'.format( - cls_name=utils.get_name(StorageDB, full_qual=True), + cls_name=utils.get_name(ValueDB, full_qual=True), ) else: module_name_set.add(inspect.getmodule(db_loader).__name__) @@ -746,8 +641,8 @@ class Expression: header += '\n\n' - # If there is no ExprValue referenced by that script, we don't need - # to access any StorageDB + # If there is no ExprVal referenced by that script, we don't need + # to access any ValueDB if expr_val_set: if db_relative_to is not None: db_relative_to = ', relative_to='+db_relative_to @@ -755,7 +650,7 @@ class Expression: db_relative_to = '' header += '{db} = {db_loader_name}({path}{db_relative_to})\n'.format( - db = db_var_name, + db = script_db.var_name, db_loader_name = db_loader_name, path = repr(str(db_path)), db_relative_to = db_relative_to @@ -777,7 +672,7 @@ class Expression: reusable_outvar_map[self] = outvar return (outvar, script) - def _get_script_internal(self, reusable_outvar_map, prefix, obj_store, module_name_set, idt, expr_val_set, consumer_expr_stack): + def _get_script_internal(self, reusable_outvar_map, prefix, script_db, module_name_set, idt, expr_val_set, consumer_expr_stack, expr_val_seq_list=[]): def make_method_self_name(expr): return expr.op.value_type.__name__.replace('.', '') @@ -812,7 +707,7 @@ class Expression: obj = getattr(expr_val, attr) utils.is_serializable(obj, raise_excep=True) - # When the ExprValue is from an Expression of the Consumer + # When the ExprVal is from an Expression of the Consumer # operator, we directly print out the name of the function that was # selected since it is not serializable callable_ = expr_val.expr.op.callable_ @@ -821,11 +716,11 @@ class Expression: elif attr == 'value' and callable_ is ExprData: return self.EXPR_DATA_VAR_NAME else: - return obj_store.get_value_snippet(obj) + return script_db.get_snippet(expr_val, attr) - def format_build_param(param_expr_val_map): + def format_build_param(param_map): out = list() - for param, expr_val in param_expr_val_map.items(): + for param, expr_val in param_map.items(): try: 
value = format_expr_val(expr_val) # Cannot be serialized, so we skip it @@ -856,7 +751,7 @@ class Expression: # The parameter we are trying to compute cannot be computed and we will # just output a skeleton with a placeholder for the user to fill it - is_user_defined = isinstance(self.op, PrebuiltOperator) and not self.result_list + is_user_defined = isinstance(self.op, PrebuiltOperator) and not expr_val_seq_list # Consumer operator is special since we don't compute anything to # get its value, it is just the name of a function @@ -866,7 +761,7 @@ class Expression: else: return (consumer_expr_stack[-2].op.get_name(full_qual=True), '') elif self.op.callable_ is ExprData: - # When we actually have an ExprValue, use it so we have the right + # When we actually have an ExprVal, use it so we have the right # UUID. if expr_val_set: # They should all have be computed using the same ExprData, @@ -874,9 +769,9 @@ class Expression: expr_val_list = [expr_val.value for expr_val in expr_val_set] assert expr_val_list[1:] == expr_val_list[:-1] - expr_data = take_first(expr_val_set) + expr_data = utils.take_first(expr_val_set) return (format_expr_val(expr_data, lambda x:''), '') - # Prior to execution, we don't have an ExprValue yet + # Prior to execution, we don't have an ExprVal yet else: is_user_defined = True @@ -896,12 +791,20 @@ class Expression: # Reusable parameter values are output first, so that non-reusable # parameters will be inside the for loops if any to be recomputed # for every combination of reusable parameters. + + def get_param_map(reusable): + return OrderedDict( + (param, param_expr) + for param, param_expr + in self.param_map.items() + if bool(param_expr.op.reusable) == reusable + ) param_map_chain = itertools.chain( - self.get_param_map(reusable=True).items(), - self.get_param_map(reusable=False).items(), + get_param_map(reusable=True).items(), + get_param_map(reusable=False).items(), ) - first_param = take_first(self.param_map.keys()) + first_param = utils.take_first(self.param_map.keys()) for param, param_expr in param_map_chain: # Rename "self" parameter for more natural-looking output @@ -914,19 +817,19 @@ class Expression: param_prefix = make_var(pretty_param) - # Get the set of ExprValue that were used to compute the - # ExprValue given in expr_val_set + # Get the set of ExprVal that were used to compute the + # ExprVal given in expr_val_set param_expr_val_set = set() for expr_val in expr_val_set: # When there is no value for that parameter, that means it # could not be computed and therefore we skip that result with contextlib.suppress(KeyError): - param_expr_val = expr_val.param_expr_val_map[param] + param_expr_val = expr_val.param_map[param] param_expr_val_set.add(param_expr_val) - # Do a deep first search traversal of the expression. + # Do a deep first traversal of the expression. 
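The ID layout built above can be summarised with a small standalone sketch (the ``make_id`` helper and the example IDs are hypothetical): the first parameter uses prefix notation with a ``:`` separator, and the remaining parameters are listed between parentheses::

    from collections import OrderedDict

    def make_id(op_name, param_id_map):
        param_id_map = OrderedDict(param_id_map)
        if not param_id_map:
            return op_name
        # First parameter uses the prefix notation, with ':' as separator
        _, first_id = param_id_map.popitem(last=False)
        id_ = first_id + ':' + op_name
        # Remaining parameters are shown between parentheses
        if param_id_map:
            id_ += '(' + ','.join(
                '{}={}'.format(param, param_id)
                for param, param_id in param_id_map.items()
            ) + ')'
        return id_

    print(make_id('compute_stats', [('trace', 'TraceBundle'), ('conf', 'StatsConf')]))
    # TraceBundle:compute_stats(conf=StatsConf)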
param_outvar, param_out = param_expr._get_script( - reusable_outvar_map, param_prefix, obj_store, module_name_set, idt, + reusable_outvar_map, param_prefix, script_db, module_name_set, idt, param_expr_val_set, consumer_expr_stack = consumer_expr_stack + [self], ) @@ -962,10 +865,10 @@ class Expression: if ( isinstance(self.op, PrebuiltOperator) and ( - not self.result_list or + not expr_val_seq_list or ( - len(self.result_list) == 1 and - len(self.result_list[0].value_list) == 1 + len(expr_val_seq_list) == 1 and + len(expr_val_seq_list[0].expr_val_list) == 1 ) ) ): @@ -998,7 +901,7 @@ class Expression: script += '\n' script += make_comment('{id}{src_loc}'.format( - id = list(self.get_id(with_tags=False, full_qual=False))[0], + id = self.get_id(with_tags=False, full_qual=False), src_loc = '\n' + src_loc if src_loc else '' ), idt_str) @@ -1011,54 +914,54 @@ class Expression: ) # Dump the serialized value - for result in self.result_list: + for expr_val_seq in expr_val_seq_list: # Make a copy to allow modifying the parameter names - param_expr_val_map = copy.copy(result.param_expr_val_map) - value_list = result.value_list + param_map = copy.copy(expr_val_seq.param_map) + expr_val_list = expr_val_seq.expr_val_list - # Restrict the list of ExprValue we are considering to the ones + # Restrict the list of ExprVal we are considering to the ones # we were asked about - value_list = [ - expr_val for expr_val in value_list + expr_val_list = [ + expr_val for expr_val in expr_val_list if expr_val in expr_val_set ] # Filter out values where nothing was computed and there was # no exception at this step either - value_list = [ - expr_val for expr_val in value_list + expr_val_list = [ + expr_val for expr_val in expr_val_list if ( (expr_val.value is not NoValue) or (expr_val.excep is not NoValue) ) ] - if not value_list: + if not expr_val_list: continue # Rename "self" parameter to the name of the variable we are # going to apply the method on if self.op.is_method: - first_param = take_first(param_expr_val_map) - param_expr_val = param_expr_val_map.pop(first_param) + first_param = utils.take_first(param_map) + param_expr_val = param_map.pop(first_param) self_param = make_var(make_method_self_name(param_expr_val.expr)) - param_expr_val_map[self_param] = param_expr_val + param_map[self_param] = param_expr_val # Multiple values to loop over try: if is_genfunc: serialized_list = '\n' + idt.style + ('\n' + idt.style).join( format_expr_val(expr_val, lambda x: ', # ' + x) - for expr_val in value_list + for expr_val in expr_val_list ) + '\n' serialized_instance = 'for {outname} in ({values}):'.format( outname = outname, values = serialized_list ) # Just one value - elif value_list: + elif expr_val_list: serialized_instance = '{outname} = {value}'.format( outname = outname, - value = format_expr_val(value_list[0]) + value = format_expr_val(expr_val_list[0]) ) # The values cannot be serialized so we hide them except utils.NotSerializableError: @@ -1076,9 +979,9 @@ class Expression: script += make_comment(serialized_instance, idt_str) # Show the origin of the values we have shown - if param_expr_val_map: + if param_map: origin = 'Built using:' + format_build_param( - param_expr_val_map + param_map ) + '\n' script += make_comment(origin, idt_str) @@ -1102,56 +1005,75 @@ class Expression: return outname, script - @classmethod - def get_executor_map(cls, expr_wrapper_list): - # Pool of deduplicated Expression - expr_set = set() - # Prepare all the wrapped Expression for execution, so they can be - # deduplicated before 
being run - for expr_wrapper in expr_wrapper_list: - # The wrapped Expression could be deduplicated so we update it - expr_wrapper.expr = expr_wrapper.expr._prepare_exec(expr_set) +class ComputableExpression(ExpressionBase): + def __init__(self, op, param_map, data=None): + self.uuid = utils.create_uuid() + self.expr_val_seq_list = list() + self.data = data if data is not None else ExprData() + super().__init__(op=op, param_map=param_map) - return { - expr_wrapper: expr_wrapper.expr.execute - for expr_wrapper in expr_wrapper_list + @classmethod + def from_expr(cls, expr, **kwargs): + param_map = { + param: cls.from_expr(param_expr) + for param, param_expr in expr.param_map.items() } + return cls( + op=expr.op, + param_map=param_map, + **kwargs, + ) @classmethod - def execute_all(cls, expr_wrapper_list, *args, **kwargs): - executor_map = cls.get_executor_map(expr_wrapper_list) + def from_expr_list(cls, expr_list): + # Apply Common Subexpression Elimination to ExpressionBase before they + # are run, and then get a bound reference of "execute" that can be + # readily iterated over to get the results. + return cls.cse( + cls.from_expr(expr) + for expr in expr_list + ) - for expr_wrapper, executor in executor_map.items(): - for expr_val in executor(*args, **kwargs): - yield (expr_wrapper, expr_val) + def _get_script(self, *args, **kwargs): + return super()._get_script(*args, **kwargs, + expr_val_seq_list=self.expr_val_seq_list + ) - def _prepare_exec(self, expr_set): - """Apply a flavor of common subexpressions elimination to the Expression - graph and cleanup results of previous runs. + def get_id(self, mark_excep=False, marked_expr_val_set=set(), **kwargs): + # Mark all the values that failed to be computed because of an + # exception + marked_expr_val_set = self.get_excep() if mark_excep else marked_expr_val_set - :return: return an updated copy of the Expression it is called on - """ - # Make a copy so we don't modify the original Expression - new_expr = copy.copy(self) - new_expr.discard_result() + return super().get_id( + marked_expr_val_set=marked_expr_val_set, + **kwargs + ) - for param, param_expr in list(new_expr.param_map.items()): - # Update the param map in case param_expr was deduplicated - new_expr.param_map[param] = param_expr._prepare_exec(expr_set) + def find_expr_val_seq_list(self, param_map): + def value_map(param_map): + return ExprValParamMap( + # Extract the actual value from ExprVal + (param, expr_val.value) + for param, expr_val in param_map.items() + ) + param_map = value_map(param_map) - # Look for an existing Expression that has the same parameters so we - # don't add duplicates. - for replacement_expr in expr_set - {new_expr}: - if ( - new_expr.op.callable_ is replacement_expr.op.callable_ and - new_expr.param_map == replacement_expr.param_map - ): - return replacement_expr + # Find the results that are matching the param_map + return [ + expr_val_seq + for expr_val_seq in self.expr_val_seq_list + # Check if param_map is a subset of the param_map + # of the ExprVal. That allows checking for reusable parameters + # only. 
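The subset test used by ``find_expr_val_seq_list()`` relies on dict item views supporting set comparison, a standard Python 3 feature. A quick illustration with plain dicts (the keys are made up)::

    computed = {'board': 'juno', 'iterations': 10}
    print({'board': 'juno'}.items() <= computed.items())   # True: a subset
    print({'board': 'hikey'}.items() <= computed.items())  # False: value differs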
+ if param_map.items() <= value_map(expr_val_seq.param_map).items() + ] - # Otherwise register this Expression so no other duplicate will be used - expr_set.add(new_expr) - return new_expr + @classmethod + def execute_all(cls, expr_list, *args, **kwargs): + for comp_expr in cls.from_expr_list(expr_list): + for expr_val in comp_expr.execute(*args, **kwargs): + yield (comp_expr, expr_val) def execute(self, post_compute_cb=None): return self._execute([], post_compute_cb) @@ -1164,161 +1086,391 @@ class Expression: # been produced reusable = self.op.reusable + def filter_param_exec_map(param_map, reusable): + return OrderedDict( + (param, param_expr._execute( + consumer_expr_stack=consumer_expr_stack + [self], + post_compute_cb=post_compute_cb, + )) + for param, param_expr in param_map.items() + if param_expr.op.reusable == reusable + ) + # Get all the generators for reusable parameters - reusable_param_exec_map = OrderedDict( - (param, param_expr._execute( - consumer_expr_stack=consumer_expr_stack + [self], - post_compute_cb=post_compute_cb, - )) - for param, param_expr in self.param_map.items() - if param_expr.op.reusable - ) - param_map_len = len(self.param_map) - reusable_param_map_len = len(reusable_param_exec_map) + reusable_param_exec_map = filter_param_exec_map(self.param_map, True) # Consume all the reusable parameters, since they are generators - for param_expr_val_map in consume_gen_map( - reusable_param_exec_map, product=ExprValue.expr_val_product + for param_map in ExprValParamMap.from_gen_map_product(self, reusable_param_exec_map): + # Check if some ExprVal are already available for the current + # set of reusable parameters. Non-reusable parameters are not + # considered since they would be different every time in any case. + if reusable and not param_map.is_partial(ignore_error=True): + # Check if we have already computed something for that + # Expression and that set of parameter values + expr_val_seq_list = self.find_expr_val_seq_list(param_map) + if expr_val_seq_list: + # Reusable objects should have only one ExprValSeq + # that was computed with a given param_map + assert len(expr_val_seq_list) == 1 + expr_val_seq = expr_val_seq_list[0] + yield from expr_val_seq.iter_expr_val() + continue + + # Only compute the non-reusable parameters if all the reusable one + # are available, otherwise that is pointless + if not param_map.is_partial(): + # Non-reusable parameters must be computed every time, and we + # don't take their cartesian product since we have fresh values + # for all operator calls. + + nonreusable_param_exec_map = filter_param_exec_map(self.param_map, False) + param_map.update(ExprValParamMap.from_gen_map(self, nonreusable_param_exec_map)) + + # Propagate exceptions if some parameters did not execute + # successfully. + if param_map.is_partial(): + expr_val = ExprVal(self, param_map) + expr_val_seq = ExprValSeq.from_one_expr_val( + self, expr_val, param_map, + ) + self.expr_val_seq_list.append(expr_val_seq) + yield expr_val + continue + + # If no value has been found, compute it and save the results in + # a list. 
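The reuse logic above amounts to memoising results per combination of reusable parameter values. A rough standalone analogy, much simpler than the engine (the ``execute`` helper is hypothetical)::

    cache = {}

    def execute(op, **params):
        # Key on the operator and the values of its (reusable) parameters
        key = (op.__name__, tuple(sorted(params.items())))
        try:
            return cache[key]
        except KeyError:
            res = cache[key] = op(**params)
            return res

    def add(x, y):
        print('computing')
        return x + y

    print(execute(add, x=1, y=2))  # computes and stores
    print(execute(add, x=1, y=2))  # replays the stored value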
+ param_val_map = OrderedDict( + # Extract the actual computed values wrapped in ExprVal + (param, param_expr_val.value) + for param, param_expr_val in param_map.items() + ) + + # Consumer operator is special and we provide the value for it, + # instead of letting it computing its own value + if self.op.callable_ is Consumer: + try: + consumer = consumer_expr_stack[-2].op.callable_ + except IndexError: + consumer = None + iterated = [ (None, consumer, NoValue) ] + + elif self.op.callable_ is ExprData: + root_expr = consumer_expr_stack[0] + expr_data = root_expr.data + iterated = [ (expr_data.uuid, expr_data, NoValue) ] + + # Otherwise, we just call the operators with its parameters + else: + iterated = self.op.generator_wrapper(**param_val_map) + + iterator = iter(iterated) + expr_val_seq = ExprValSeq( + self, iterator, param_map, + post_compute_cb + ) + self.expr_val_seq_list.append(expr_val_seq) + yield from expr_val_seq.iter_expr_val() + + def get_all_vals(self): + return utils.flatten_seq( + expr_val_seq.expr_val_list + for expr_val_seq in self.expr_val_seq_list + ) + + def get_excep(self): + return set(utils.flatten_seq( + expr_val.get_excep() + for expr_val in self.get_all_vals() + )) + +class ClassContext: + def __init__(self, op_map, cls_map): + self.op_map = op_map + self.cls_map = cls_map + + @staticmethod + def _build_cls_map(op_set, compat_cls): + # Pool of classes that can be produced by the ops + produced_pool = set(op.value_type for op in op_set) + + # Set of all types that can be depended upon. All base class of types that + # are actually produced are also part of this set, since they can be + # dependended upon as well. + cls_set = set() + for produced in produced_pool: + cls_set.update(utils.get_mro(produced)) + cls_set.discard(object) + cls_set.discard(type(None)) + + # Map all types to the subclasses that can be used when the type is + # requested. + return { + # Make sure the list is deduplicated by building a set first + cls: sorted({ + subcls for subcls in produced_pool + if compat_cls(subcls, cls) + }, key=lambda cls: cls.__qualname__) + for cls in cls_set + } + + # Map of all produced types to a set of what operator can create them + @staticmethod + def _build_op_map(op_set, cls_map, forbidden_pattern_set): + # Make sure that the provided PrebuiltOperator will be the only ones used + # to provide their types + only_prebuilt_cls = set(itertools.chain.from_iterable( + # Augment the list of classes that can only be provided by a prebuilt + # Operator with all the compatible classes + cls_map[op.obj_type] + for op in op_set + if isinstance(op, PrebuiltOperator) + )) + + op_map = dict() + for op in op_set: + param_map, produced = op.get_prototype() + is_prebuilt_op = isinstance(op, PrebuiltOperator) + if ( + (is_prebuilt_op or produced not in only_prebuilt_cls) + and not utils.match_base_cls(produced, forbidden_pattern_set) ): - # If some parameters could not be computed, we will not get all - # values - reusable_param_computed = ( - len(param_expr_val_map) == reusable_param_map_len + op_map.setdefault(produced, set()).add(op) + return op_map + + @staticmethod + def _restrict_op_map(op_map, cls_map, restricted_pattern_set): + cls_map = copy.copy(cls_map) + + # Restrict the production of some types to a set of operators. 
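``_build_cls_map()`` effectively inverts the MRO of every produced type, so that asking for a base class can be satisfied by any produced subclass. A standalone sketch of that idea using ``inspect.getmro()`` (the toy classes ``Base``, ``A`` and ``B`` are made up)::

    import inspect

    class Base: pass
    class A(Base): pass
    class B(Base): pass

    produced_pool = {A, B}
    cls_map = {}
    for produced in produced_pool:
        for cls in inspect.getmro(produced):
            if cls is object:
                continue
            cls_map.setdefault(cls, set()).add(produced)

    # Asking for Base can be satisfied by any produced subclass
    print(sorted(c.__name__ for c in cls_map[Base]))  # ['A', 'B']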
+ restricted_op_set = { + # Make sure that we only use what is available + op for op in itertools.chain.from_iterable(op_map.values()) + if utils.match_name(op.get_name(full_qual=True), restricted_pattern_set) + } + def apply_restrict(produced, op_set, restricted_op_set, cls_map): + restricted_op_set = { + op for op in restricted_op_set + if op.value_type is produced + } + if restricted_op_set: + # Make sure there is no other compatible type, so the only + # operators that will be used to satisfy that dependency will + # be one of the restricted_op_set item. + cls_map[produced] = [produced] + return restricted_op_set + else: + return op_set + op_map = { + produced: apply_restrict(produced, op_set, restricted_op_set, cls_map) + for produced, op_set in op_map.items() + } + + return (op_map, cls_map) + + @classmethod + def from_op_set(cls, op_set, forbidden_pattern_set=set(), restricted_pattern_set=set(), compat_cls=issubclass): + # Build the mapping of compatible classes + cls_map = cls._build_cls_map(op_set, compat_cls) + # Build the mapping of classes to producing operators + op_map = cls._build_op_map(op_set, cls_map, forbidden_pattern_set) + op_map, cls_map = cls._restrict_op_map(op_map, cls_map, restricted_pattern_set) + + return cls( + op_map=op_map, + cls_map=cls_map + ) + + def build_expr_list(self, result_op_seq, + non_produced_handler='raise', cycle_handler='raise'): + op_map = copy.copy(self.op_map) + cls_map = { + cls: compat_cls_set + for cls, compat_cls_set in self.cls_map.items() + # If there is at least one compatible subclass that is produced, we + # keep it, otherwise it will mislead _build_expr into thinking the + # class can be built where in fact it cannot + if compat_cls_set & op_map.keys() + } + internal_cls_set = {Consumer, ExprData} + for internal_cls in internal_cls_set: + op_map[internal_cls] = { + Operator(internal_cls, non_reusable_type_set=internal_cls_set) + } + cls_map[internal_cls] = [internal_cls] + + expr_list = list() + for result_op in result_op_seq: + expr_gen = self._build_expr(result_op, op_map, cls_map, + op_stack = [], + non_produced_handler=non_produced_handler, + cycle_handler=cycle_handler, ) + for expr in expr_gen: + if expr.validate(op_map): + expr_list.append(expr) + + # Apply CSE to get a cleaner result + return Expression.cse(expr_list) + + @classmethod + def _build_expr(cls, op, op_map, cls_map, op_stack, non_produced_handler, cycle_handler): + new_op_stack = [op] + op_stack + # We detected a cyclic dependency + if op in op_stack: + if cycle_handler == 'ignore': + return + elif callable(cycle_handler): + cycle_handler(tuple(op.callable_ for op in new_op_stack)) + return + elif cycle_handler == 'raise': + raise CycleError('Cyclic dependency found: {path}'.format( + path = ' -> '.join( + op.name for op in new_op_stack + ) + )) + else: + raise ValueError('Invalid cycle_handler') + + op_stack = new_op_stack + + param_map, produced = op.get_prototype() + if param_map: + param_list, cls_list = zip(*param_map.items()) + # When no parameter is needed + else: + yield Expression(op, OrderedDict()) + return + + # Build all the possible combinations of types suitable as parameters + cls_combis = [cls_map.get(cls, list()) for cls in cls_list] + + # Only keep the classes for "self" on which the method can be applied + if op.is_method: + cls_combis[0] = [ + cls for cls in cls_combis[0] + # If the method with the same name would resolve to "op", then + # we keep this class as a candidate for "self", otherwise we + # discard it + if getattr(cls, 
op.callable_.__name__, None) is op.callable_ + ] + + # Check that some produced classes are available for every parameter + ignored_indices = set() + for param, wanted_cls, available_cls in zip(param_list, cls_list, cls_combis): + if not available_cls: + # If that was an optional parameter, just ignore it without + # throwing an exception since it has a default value + if param in op.optional_param: + ignored_indices.add(param_list.index(param)) + else: + if non_produced_handler == 'ignore': + return + elif callable(non_produced_handler): + non_produced_handler(wanted_cls.__qualname__, op.name, param, + tuple(op.resolved_callable for op in op_stack) + ) + return + elif non_produced_handler == 'raise': + raise NoOperatorError('No operator can produce instances of {cls} needed for {op} (parameter "{param}" along path {path})'.format( + cls = wanted_cls.__qualname__, + op = op.name, + param = param, + path = ' -> '.join( + op.name for op in op_stack + ) + )) + else: + raise ValueError('Invalid non_produced_handler') + + param_list = utils.remove_indices(param_list, ignored_indices) + cls_combis = utils.remove_indices(cls_combis, ignored_indices) + + param_list_len = len(param_list) + + # For all possible combinations of types + for cls_combi in itertools.product(*cls_combis): + cls_combi = list(cls_combi) + + # Some classes may not be produced, but another combination + # with containing a subclass of it may actually be produced so we can + # just ignore that one. + op_combis = [ + op_map[cls] for cls in cls_combi + if cls in op_map + ] - # Check if some ExprValue are already available for the current - # set of reusable parameters. Non-reusable parameters are not - # considered since they would be different every time in any case. - if reusable and reusable_param_computed: - # Check if we have already computed something for that - # Expression and that set of parameter values - result_list = self.find_result_list(param_expr_val_map) - if result_list: - # Reusable objects should have only one ExprValueSeq - # that was computed with a given param_expr_val_map - assert len(result_list) == 1 - expr_val_seq = result_list[0] - yield from expr_val_seq.iter_expr_val() - continue + # Build all the possible combinations of operators returning these + # types + for op_combi in itertools.product(*op_combis): + op_combi = list(op_combi) - # Only compute the non-reusable parameters if all the reusable one - # are available, otherwise that is pointless - if ( - reusable_param_computed and - not any_value_is_NoValue(param_expr_val_map.values()) - ): - # Non-reusable parameters must be computed every time, and we - # don't take their cartesian product since we have fresh values - # for all operator calls. - nonreusable_param_exec_map = OrderedDict( - (param, param_expr._execute( - consumer_expr_stack=consumer_expr_stack + [self], - post_compute_cb=post_compute_cb, - )) - for param, param_expr in self.param_map.items() - if not param_expr.op.reusable - ) - param_expr_val_map.update(next( - consume_gen_map(nonreusable_param_exec_map, product=no_product) + # Get all the possible ways of calling these operators + param_combis = itertools.product(*(cls._build_expr( + param_op, op_map, cls_map, + op_stack, non_produced_handler, cycle_handler, + ) for param_op in op_combi )) - # Propagate exceptions if some parameters did not execute - # successfully. 
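The nested loops above enumerate the cartesian product of the candidate providers for each parameter. A toy illustration of that enumeration, with plain strings standing in for operators::

    import itertools

    candidates = {
        'trace': ['parse_trace', 'load_cached_trace'],
        'conf': ['default_conf'],
    }
    for combi in itertools.product(*candidates.values()):
        print(dict(zip(candidates, combi)))
    # {'trace': 'parse_trace', 'conf': 'default_conf'}
    # {'trace': 'load_cached_trace', 'conf': 'default_conf'}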
- if ( - # Some arguments are missing: there was no attempt to compute - # them because another argument failed to be computed - len(param_expr_val_map) != param_map_len or - # Or one of the arguments could not be computed - any_value_is_NoValue(param_expr_val_map.values()) - ): - expr_val = ExprValue(self, param_expr_val_map) - expr_val_seq = ExprValueSeq.from_one_expr_val( - self, expr_val, param_expr_val_map, - post_compute_cb=post_compute_cb, - ) - self.result_list.append(expr_val_seq) - yield expr_val - continue + for param_combi in param_combis: + param_map = OrderedDict(zip(param_list, param_combi)) - # If no value has been found, compute it and save the results in - # a list. - param_val_map = OrderedDict( - # Extract the actual computed values wrapped in ExprValue - (param, param_expr_val.value) - for param, param_expr_val in param_expr_val_map.items() - ) + # If all parameters can be built, carry on + if len(param_map) == param_list_len: + yield Expression(op, param_map) - # Consumer operator is special and we provide the value for it, - # instead of letting it computing its own value - if self.op.callable_ is Consumer: - try: - consumer = consumer_expr_stack[-2].op.callable_ - except IndexError: - consumer = None - iterated = [ ((consumer, None), (NoValue, None)) ] +class Expression(ExpressionBase): + def validate(self, op_map): + type_map, valid = self._get_type_map() + if not valid: + return False - elif self.op.callable_ is ExprData: - root_expr = consumer_expr_stack[0] - iterated = [ ((root_expr.data, root_expr.data_uuid), (NoValue, None)) ] + # Check that the Expression does not involve 2 classes that are + # compatible + cls_bags = [set(cls_list) for cls_list in op_map.values()] + cls_used = set(type_map.keys()) + for cls1, cls2 in itertools.product(cls_used, repeat=2): + for cls_bag in cls_bags: + if cls1 in cls_bag and cls2 in cls_bag: + return False - # Otherwise, we just call the operators with its parameters - else: - iterated = self.op.generator_wrapper(**param_val_map) + return True - iterator = iter(iterated) - expr_val_seq = ExprValueSeq( - self, iterator, param_expr_val_map, - post_compute_cb - ) - self.result_list.append(expr_val_seq) - yield from expr_val_seq.iter_expr_val() + def _get_type_map(self): + type_map = dict() + return (type_map, self._populate_type_map(type_map)) + + def _populate_type_map(self, type_map): + value_type = self.op.value_type + # If there was already an Expression producing that type, the Expression + # is not valid + found_callable = type_map.get(value_type) + if found_callable is not None and found_callable is not self.op.callable_: + return False + type_map[value_type] = self.op.callable_ -def infinite_iter(generator, value_list, from_gen): - """Exhaust the `generator` when `from_gen=True`, yield from `value_list` - otherwise. - """ - if from_gen: - for value in generator: - value_list.append(value) - yield value - else: - yield from value_list - -def no_product(*gen_list): - # Take only one value from each generator, since non-reusable - # operators are not supposed to produce more than one value. 
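A sketch of the validation rule implemented by ``_populate_type_map()``: walking the expression tree must never find two different callables producing the same type. Here tuples of ``(value_type, callable, params)`` stand in for the engine's ``Expression`` objects::

    def populate_type_map(expr, type_map):
        value_type, callable_, params = expr
        found = type_map.get(value_type)
        if found is not None and found != callable_:
            return False
        type_map[value_type] = callable_
        return all(populate_type_map(p, type_map) for p in params)

    expr = ('Result', 'compute', [
        ('Trace', 'parse', []),
        ('Trace', 'reload', []),
    ])
    print(populate_type_map(expr, {}))  # False: two producers for 'Trace'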
- yield [next(generator) for generator in gen_list] - -def consume_gen_map(param_map, product): - """ - :param product: Function implementing a the same interface as - :func:`itertools.product` - """ - if not param_map: - yield OrderedDict() - else: - # sort to make sure we always compute the parameters in the same order - param_list, gen_list = zip(*param_map.items()) - for values in product(*gen_list): - yield OrderedDict(zip(param_list, values)) + for param_expr in self.param_map.values(): + if not param_expr._populate_type_map(type_map): + return False + return True class AnnotationError(Exception): pass +class ForcedParamType: + pass + class Operator: def __init__(self, callable_, non_reusable_type_set=None, tags_getter=None): if non_reusable_type_set is None: non_reusable_type_set = set() if not tags_getter: - tags_getter = lambda v: [] + tags_getter = lambda v: {} self.tags_getter = tags_getter assert callable(callable_) self.callable_ = callable_ - self.signature = inspect.signature(self.resolved_callable) - self.callable_globals = self.resolved_callable.__globals__ self.annotations = copy.copy(self.resolved_callable.__annotations__) self.ignored_param = { @@ -1356,32 +1508,52 @@ class Operator: # easily. self.annotations['return'] = self.resolved_callable.__self__ + @property + def callable_globals(self): + return self.resolved_callable.__globals__ + + @property + def signature(self): + return inspect.signature(self.resolved_callable) + def __repr__(self): return '' def force_param(self, param_callable_map, tags_getter=None): - def define_type(param_type): - class ForcedType(param_type): - pass - - # Make it transparent for better reporting - ForcedType.__qualname__ = param_type.__qualname__ - ForcedType.__name__ = param_type.__name__ - ForcedType.__module__ = param_type.__module__ - return ForcedType - prebuilt_op_set = set() for param, value_list in param_callable_map.items(): - # We just get the type of the first item in the list, which should - # work in most cases - param_type = type(take_first(value_list)) + # Get the most derived class that is in common between all + # instances + value_type = utils.get_common_base(type(v) for v in value_list) + + try: + param_annot = self.annotations[param] + except KeyError: + pass + else: + # If there was an annotation, make sure the type we computed is + # compatible with what the annotation specifies. + assert issubclass(value_type, param_annot) + + # We do not inherit from value_type, since it may not always work, + # e.g. subclassing bool is forbidden. Therefore, it is purely used + # as a unique marker. + class ParamType(ForcedParamType): + pass + + # References to this type won't be serializable with pickle, but + # instances will be. This is because pickle checks that only one + # type exists with a given __module__ and __qualname__. 
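A minimal sketch of the marker-type trick used by ``force_param()``, assuming nothing beyond the standard library: the fresh class is renamed after the real value type so IDs read naturally, while remaining a distinct type usable as an annotation key::

    class ForcedParamType:
        pass

    def make_marker(value_type):
        class ParamType(ForcedParamType):
            pass
        # Renamed for display purposes only: it stays a distinct type
        ParamType.__name__ = value_type.__name__
        ParamType.__qualname__ = value_type.__qualname__
        ParamType.__module__ = value_type.__module__
        return ParamType

    marker = make_marker(int)
    print(marker.__name__, marker is not int, issubclass(marker, ForcedParamType))
    # int True True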
+ ParamType.__name__ = value_type.__name__ + ParamType.__qualname__ = value_type.__qualname__ + ParamType.__module__ = value_type.__module__ # Create an artificial new type that will only be produced by # the PrebuiltOperator - ForcedType = define_type(param_type) - self.annotations[param] = ForcedType + self.annotations[param] = ParamType + prebuilt_op_set.add( - PrebuiltOperator(ForcedType, value_list, + PrebuiltOperator(ParamType, value_list, tags_getter=tags_getter )) @@ -1409,13 +1581,30 @@ class Operator: except AttributeError: return None - def get_id(self, full_qual=True, qual=True): - # Factory classmethods are replaced by the class name when not - # asking for a qualified ID - if not qual and self.is_factory_cls_method: - return utils.get_name(self.value_type, full_qual=full_qual, qual=qual) + def get_id(self, full_qual=True, qual=True, style=None): + if style == 'rst': + if self.is_factory_cls_method: + qualname = utils.get_name(self.value_type, full_qual=True) + else: + qualname = self.get_name(full_qual=True) + name = self.get_id(full_qual=full_qual, qual=qual, style=None) + + if self.is_class: + role = 'class' + elif self.is_method or self.is_static_method or self.is_cls_method: + role = 'meth' + else: + role = 'func' + + return ':{role}:`{name}<{qualname}>`'.format(role=role, name=name, qualname=qualname) + else: - return self.get_name(full_qual=full_qual, qual=qual) + # Factory classmethods are replaced by the class name when not + # asking for a qualified ID + if not (qual or full_qual) and self.is_factory_cls_method: + return utils.get_name(self.value_type, full_qual=full_qual, qual=qual) + else: + return self.get_name(full_qual=full_qual, qual=qual) @property def name(self): @@ -1515,30 +1704,31 @@ class Operator: has_yielded = False for res in self.callable_(*args, **kwargs): has_yielded = True - yield (res, utils.create_uuid()), (NoValue, None) + yield (utils.create_uuid(), res, NoValue) # If no value at all were produced, we still need to yield # something if not has_yielded: - yield (NoValue, None), (NoValue, None) + yield (utils.create_uuid(), NoValue, NoValue) except Exception as e: - yield (NoValue, None), (e, utils.create_uuid()) + yield (utils.create_uuid(), NoValue, e) else: @functools.wraps(self.callable_) def genf(*args, **kwargs): + uuid_ = utils.create_uuid() # yield one value and then return try: val = self.callable_(*args, **kwargs) - yield (val, utils.create_uuid()), (NoValue, None) + yield (uuid_, val, NoValue) except Exception as e: - yield (NoValue, None), (e, utils.create_uuid()) + yield (uuid_, NoValue, e) return genf def get_prototype(self): sig = self.signature - first_param = take_first(sig.parameters) + first_param = utils.take_first(sig.parameters) annotation_map = utils.resolve_annotations(self.annotations, self.callable_globals) extra_ignored_param = set() @@ -1560,7 +1750,13 @@ class Operator: cls_name = self.resolved_callable.__qualname__.split('.')[0] self.annotations[first_param] = cls_name - produced = annotation_map['return'] + # No return annotation is accepted and is equivalent to None return + # annotation + produced = annotation_map.get('return') + # "None" annotation is accepted, even though it is not a type + # strictly speaking + if produced is None: + produced = type(None) # Recompute after potentially modifying the annotations annotation_map = utils.resolve_annotations(self.annotations, self.callable_globals) @@ -1607,8 +1803,8 @@ class PrebuiltOperator(Operator): for obj in obj_list: # Transparently copy the UUID to avoid 
having multiple UUIDs # refering to the same actual value. - if isinstance(obj, SerializableExprValue): - uuid_ = obj.value_uuid + if isinstance(obj, FrozenExprVal): + uuid_ = obj.uuid obj = obj.value else: uuid_ = utils.create_uuid() @@ -1629,7 +1825,7 @@ class PrebuiltOperator(Operator): def get_name(self, *args, **kwargs): return None - def get_id(self, *args, **kwargs): + def get_id(self, *args, style=None, **kwargs): return self._id or utils.get_name(self.obj_type, *args, **kwargs) @property @@ -1647,32 +1843,34 @@ class PrebuiltOperator(Operator): @property def generator_wrapper(self): def genf(): - for obj, uuid_ in zip(self.obj_list, self.uuid_list): - yield (obj, uuid_), (NoValue, None) + yield from zip(self.uuid_list, self.obj_list, itertools.repeat(NoValue)) return genf -class ExprValueSeq: - def __init__(self, expr, iterator, param_expr_val_map, post_compute_cb=None): +class ExprValSeq: + def __init__(self, expr, iterator, param_map, post_compute_cb=None): self.expr = expr assert isinstance(iterator, collections.abc.Iterator) self.iterator = iterator - self.value_list = [] - self.param_expr_val_map = param_expr_val_map + self.expr_val_list = [] + self.param_map = param_map self.post_compute_cb = post_compute_cb - @classmethod - def from_one_expr_val(cls, expr, expr_val, param_expr_val_map, post_compute_cb=None): - iterated = [( - (expr_val.value, expr_val.value_uuid), - (expr_val.excep, expr_val.excep_uuid), - )] + def from_one_expr_val(cls, expr, expr_val, param_map): + iterated = [ + (expr_val.uuid, expr_val.value, expr_val.excep) + ] new = cls( expr=expr, iterator=iter(iterated), - param_expr_val_map=param_expr_val_map, - post_compute_cb=post_compute_cb + param_map=param_map, + # no post_compute_cb, since we are not really going to compute + # anything + post_compute_cb=None, ) + # consume the iterator to make sure new.expr_val_list is updated + for _ in new.iter_expr_val(): + pass return new def iter_expr_val(self): @@ -1686,226 +1884,379 @@ class ExprValueSeq: yield x # Yield existing values - yield from yielder(self.value_list, True) + yield from yielder(self.expr_val_list, True) # Then compute the remaining ones if self.iterator: - for (value, value_uuid), (excep, excep_uuid) in self.iterator: - expr_val = ExprValue(self.expr, self.param_expr_val_map, - value, value_uuid, - excep, excep_uuid + for uuid_, value, excep in self.iterator: + expr_val = ExprVal( + expr=self.expr, + param_map=self.param_map, + value=value, + excep=excep, + uuid=uuid_, ) callback(expr_val, reused=False) - self.value_list.append(expr_val) - value_list_len = len(self.value_list) + self.expr_val_list.append(expr_val) + expr_val_list_len = len(self.expr_val_list) yield expr_val - # If value_list length has changed, catch up with the values + # If expr_val_list length has changed, catch up with the values # that were computed behind our back, so that this generator is # reentrant. - if value_list_len != len(self.value_list): + if expr_val_list_len != len(self.expr_val_list): # This will yield all values, even if the list grows while # we are yielding the control back to another piece of code. 
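``iter_expr_val()`` above is written to be reentrant: values already computed, possibly by another consumer of the same sequence, are replayed from the list before the iterator is advanced further. A standalone illustration of the same pattern (the ``CachedSeq`` class is hypothetical)::

    class CachedSeq:
        def __init__(self, iterator):
            self.iterator = iter(iterator)
            self.cache = []

        def __iter__(self):
            i = 0
            while True:
                # Yield values already computed, possibly by another consumer
                while i < len(self.cache):
                    yield self.cache[i]
                    i += 1
                try:
                    self.cache.append(next(self.iterator))
                except StopIteration:
                    return

    seq = CachedSeq(x * x for x in range(3))
    it1 = iter(seq)
    print(next(it1), next(it1))  # computes 0 and 1
    print(list(seq))             # replays 0 and 1, then computes 4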
yield from yielder( - self.value_list[value_list_len:], + self.expr_val_list[expr_val_list_len:], True ) self.iterator = None -def any_value_is_NoValue(value_list): - return any( - expr_val.value is NoValue - for expr_val in value_list - ) -class SerializableExprValue: - def __init__(self, expr_val, serialized_map, hidden_callable_set=None): - self.value = expr_val.value if utils.is_serializable(expr_val.value) else NoValue - self.excep = expr_val.excep if utils.is_serializable(expr_val.excep) else NoValue +class ExprValParamMap(OrderedDict): + def is_partial(self, ignore_error=False): + def is_partial(expr_val): + # Some arguments are missing: there was no attempt to compute + # them because another argument failed to be computed + if isinstance(expr_val, UnEvaluatedExprVal): + return True - self.value_uuid = expr_val.value_uuid - self.excep_uuid = expr_val.excep_uuid + # Or computation did take place but failed + if expr_val.value is NoValue and not ignore_error: + return True - self.callable_qual_name = expr_val.expr.op.get_name(full_qual=True) - self.callable_name = expr_val.expr.op.get_name(full_qual=False, qual=False) + return False - # Pre-compute all the IDs so they are readily available once the value - # is deserialized - self.recorded_id_map = dict() - for full_qual, qual, with_tags in itertools.product((True, False), repeat=3): - self.recorded_id_map[(full_qual, qual, with_tags)] = expr_val.get_id( - full_qual=full_qual, - qual=qual, - with_tags=with_tags, - hidden_callable_set=hidden_callable_set, - ) + return any( + is_partial(expr_val) + for expr_val in self.values() + ) - self.type_names = [ - utils.get_name(type_, full_qual=True) - for type_ in utils.get_mro(expr_val.expr.op.value_type) - if type_ is not object - ] + @classmethod + def from_gen_map(cls, expr, param_gen_map): + # Pre-fill UnEvaluatedExprVal with in case we exit the loop early + param_map = cls( + (param, UnEvaluatedExprVal(expr)) + for param in param_gen_map.keys() + ) - self.param_expr_val_map = OrderedDict() - for param, param_expr_val in expr_val.param_expr_val_map.items(): - param_serialzable = param_expr_val._get_serializable( - serialized_map, - hidden_callable_set=hidden_callable_set - ) - self.param_expr_val_map[param] = param_serialzable + for param, generator in param_gen_map.items(): + val = next(generator) + # There is no point in computing values of the other generators if + # one failed to produce a useful value + if val.value is NoValue: + break + else: + param_map[param] = val - def get_id(self, full_qual=True, qual=True, with_tags=True): - args = (full_qual, qual, with_tags) - return self.recorded_id_map[args] + return param_map - def get_parent_set(self, predicate, _parent_set=None): - parent_set = set() if _parent_set is None else _parent_set - if predicate(self): - parent_set.add(self) + @classmethod + def from_gen_map_product(cls, expr, param_gen_map): + """ + Yield :class:`collections.OrderedDict` for each combination of parameter + values. - for parent in self.param_expr_val_map.values(): - parent.get_parent_set(predicate, _parent_set=parent_set) + :param param_gen_map: Mapping of parameter names to an iterator that is ready + to generate the possible values for the generator. 
+ :type param_gen_map: collections.OrderedDict - return parent_set + """ + if not param_gen_map: + yield cls() + else: + # Since param_gen_map is an OrderedDict, we will always consume + # parameters in the same order + param_list, gen_list = zip(*param_gen_map.items()) + for values in cls._product(expr, gen_list): + yield cls(zip(param_list, values)) -class ExprValue: - def __init__(self, expr, param_expr_val_map, - value=NoValue, value_uuid=None, - excep=NoValue, excep_uuid=None, - ): + @classmethod + def _product(cls, expr, gen_list): + """ + Similar to the cartesian product provided by itertools.product, with + special handling of NoValue and some checks on the yielded sequences. + + It will only yield the combinations of values that are validated by + :meth:`validate`. + """ + def validated(generator): + """ + Ensure we only yield valid lists of :class:`ExprVal` + """ + for expr_val_list in generator: + if ExprVal.validate(expr_val_list): + yield expr_val_list + else: + continue + + def acc_product(product_generator, generator): + """ + Combine a "cartesian-product-style" generator with a plain + generator, giving a new "cartesian-product-style" generator. + """ + # We will need to use it more than once in the inner loop, so it + # has to be "restartable" (like a list, and unlike a plain + # iterator) + product_iter = utils.RestartableIter(product_generator) + for expr_val in generator: + # The value is not useful, we can return early without calling + # the other generators. That avoids spending time computing + # parameters if they won't be used anyway. + if expr_val.value is NoValue: + # Returning an incomplete list will make the calling code + # aware that some values were not computed at all + yield [expr_val] + else: + for expr_val_list in product_iter: + yield [expr_val] + expr_val_list + + def reducer(product_generator, generator): + yield from validated(acc_product(product_generator, generator)) + + def initializer(): + yield [] + + # We need to pad since we may truncate the list of values we yield if + # we detect an error in one of them. + def pad(generator, length): + for xs in generator: + xs.extend( + UnEvaluatedExprVal(expr) + for i in range(length - len(xs)) + ) + yield xs + + # reverse the gen_list so we get the rightmost generator varying the + # fastest. Typically, margins-like parameter on which we do sweeps are + # on the right side of the parameter list (to have a default value) + return pad( + functools.reduce(reducer, reversed(gen_list), initializer()), + len(gen_list) + ) + +class ExprValBase(collections.abc.Mapping): + def __init__(self, param_map, value, excep): + self.param_map = param_map self.value = value - self.value_uuid = value_uuid self.excep = excep - self.excep_uuid = excep_uuid - self.expr = expr - self.param_expr_val_map = param_expr_val_map - def format_tags(self): - tag_map = self.expr.op.tags_getter(self.value) - if tag_map: - return ''.join( - '[{}={}]'.format(k, v) if k else '[{}]'.format(val) - for k, v in sorted(tag_map.items()) - ) + def get_by_predicate(self, predicate): + return list(self._get_by_predicate(predicate)) + + def _get_by_predicate(self, predicate): + if predicate(self): + yield self + + for val in self.param_map.values(): + yield from val._get_by_predicate(predicate) + + def get_excep(self): + """ + Get all the failed parents. 
+ """ + def predicate(val): + return val.excep is not NoValue + + return self.get_by_predicate(predicate) + + def __eq__(self, other): + return self is other + + def __hash__(self): + # consistent with definition of __eq__ + return id(self) + + def __getitem__(self, k): + if k == 'return': + return self.value else: - return '' + return self.param_map[k] - def _get_serializable(self, serialized_map, *args, **kwargs): - if serialized_map is None: - serialized_map = dict() + def __len__(self): + # account for 'return' + return len(self.param_map) + 1 - try: - return serialized_map[self] - except KeyError: - serializable = SerializableExprValue(self, serialized_map, *args, **kwargs) - serialized_map[self] = serializable - return serializable + def __iter__(self): + return itertools.chain(self.param_map.keys(), ['return']) + +class FrozenExprVal(ExprValBase): + def __init__(self, + param_map, value, excep, uuid, + callable_qualname, callable_name, recorded_id_map, + ): + self.uuid = uuid + self.callable_qualname = callable_qualname + self.callable_name = callable_name + self.recorded_id_map = recorded_id_map + super().__init__(param_map=param_map, value=value, excep=excep) + + @property + def type_names(self): + return [ + utils.get_name(type_, full_qual=True) + for type_ in utils.get_mro(type(value)) + if type_ is not object + ] @classmethod - def validate_expr_val_list(cls, expr_val_list): - if not expr_val_list: - return True + def from_expr_val(cls, expr_val, hidden_callable_set=None): + value = expr_val.value if utils.is_serializable(expr_val.value) else NoValue + excep = expr_val.excep if utils.is_serializable(expr_val.excep) else NoValue - expr_val_ref = expr_val_list[0] - expr_map_ref = expr_val_ref._get_expr_map() - - for expr_val in expr_val_list[1:]: - expr_map = expr_val._get_expr_map() - # For all Expression's that directly or indirectly lead to both the - # reference ExprValue and the ExprValue, check that it had the same - # value. That ensures that we are not making incompatible combinations. - - if not all( - expr_map_ref[expr] is expr_map[expr] - for expr - in expr_map.keys() & expr_map_ref.keys() - # We don't consider the non-reusable parameters since it is - # expected that they will differ - if expr.op.reusable - ): - return False + op = expr_val.expr.op - if not cls.validate_expr_val_list(expr_val_list[2:]): - return False + # Reloading these values will lead to issues, and they are regenerated + # for any new Expression that would be created anyway. + if op.callable_ in (ExprData, Consumer): + value = NoValue + excep = NoValue - return True + callable_qualname = op.get_name(full_qual=True) + callable_name = op.get_name(full_qual=False, qual=False) - @classmethod - def expr_val_product(cls, *gen_list): - """Similar to the cartesian product provided by itertools.product, with - special handling of NoValue and some checks on the yielded sequences. + # Pre-compute all the IDs so they are readily available once the value + # is deserialized + recorded_id_map = dict() + for full_qual, qual, with_tags in itertools.product((True, False), repeat=3): + key = cls._make_id_key( + full_qual=full_qual, + qual=qual, + with_tags=with_tags + ) + recorded_id_map[key] = expr_val.get_id( + **dict(key), + hidden_callable_set=hidden_callable_set, + ) - It will only yield the combinations of values that are validated by - :meth:`validate_expr_val_list`. 
- """ + param_map = ExprValParamMap( + (param, cls.from_expr_val( + param_expr_val, + hidden_callable_set=hidden_callable_set, + )) + for param, param_expr_val in expr_val.param_map.items() + ) - generator = gen_list[0] - sub_generator_list = gen_list[1:] - sub_generator_list_iterator = cls.expr_val_product(*sub_generator_list) - if sub_generator_list: - from_gen = True - value_list = list() - for expr_val in generator: - # The value is not useful, we can return early without calling the - # other generators. That avoids spending time computing parameters - # if they won't be used anyway. - if expr_val.value is NoValue: - # Returning an incomplete list will make the calling code aware - # that some values were not computed at all - yield [expr_val] - continue + froz_val = cls( + uuid=expr_val.uuid, + value=value, + excep=excep, + callable_qualname=callable_qualname, + callable_name=callable_name, + recorded_id_map=recorded_id_map, + param_map=param_map, + ) - for sub_expr_val_list in infinite_iter( - sub_generator_list_iterator, value_list, from_gen - ): - expr_val_list = [expr_val] + sub_expr_val_list - if cls.validate_expr_val_list(expr_val_list): - yield expr_val_list + return froz_val - # After the first traversal of sub_generator_list_iterator, we - # want to yield from the saved value_list - from_gen = False - else: - for expr_val in generator: - expr_val_list = [expr_val] - if cls.validate_expr_val_list(expr_val_list): - yield expr_val_list + @staticmethod + def _make_id_key(**kwargs): + return tuple(sorted(kwargs.items())) + def get_id(self, full_qual=True, qual=True, with_tags=True): + key = self._make_id_key( + full_qual=full_qual, + qual=qual, + with_tags=with_tags + ) + return self.recorded_id_map[key] - def get_id(self, *args, with_tags=True, **kwargs): - # There exists only one ID for a given ExprValue so we just return it - # instead of an iterator. 
- return take_first(self.expr.get_id(with_tags=with_tags, - expr_val=self, *args, **kwargs)) +class FrozenExprValSeq(collections.abc.Sequence): + def __init__(self, froz_val_list, param_map): + self.froz_val_list = froz_val_list + self.param_map = param_map - def get_parent_expr_vals(self, predicate): - yield from self._get_parent_expr_vals(predicate) + def __getitem__(self, k): + return self.froz_val_list[k] - def _get_parent_expr_vals(self, predicate, param=None): - if predicate(self, param): - yield self + def __len__(self): + return len(self.froz_val_list) + + @classmethod + def from_expr_val_seq(cls, expr_val_seq, **kwargs): + return cls( + froz_val_list=[ + FrozenExprVal.from_expr_val(expr_val, **kwargs) + for expr_val in expr_val_seq.expr_val_list + ], + param_map={ + param: FrozenExprVal.from_expr_val(expr_val, **kwargs) + for param, expr_val in expr_val_seq.param_map.items() + } + ) + + @classmethod + def from_expr_list(cls, expr_list, **kwargs): + expr_val_seq_list = utils.flatten_seq(expr.expr_val_seq_list for expr in expr_list) + return [ + cls.from_expr_val_seq(expr_val_seq, **kwargs) + for expr_val_seq in expr_val_seq_list + ] - for param, expr_val in self.param_expr_val_map.items(): - yield from expr_val._get_parent_expr_vals(predicate, param) - def get_failed_expr_vals(self): - def predicate(expr_val, param): - return expr_val.excep is not NoValue +class ExprVal(ExprValBase): + def __init__(self, expr, param_map, + value=NoValue, excep=NoValue, uuid=None, + ): + self.uuid = uuid if uuid is not None else utils.create_uuid() + self.expr = expr + super().__init__(param_map=param_map, value=value, excep=excep) - yield from self.get_parent_expr_vals(predicate) + def format_tags(self): + tag_map = self.expr.op.tags_getter(self.value) + if tag_map: + return ''.join( + '[{}={}]'.format(k, v) if k else '[{}]'.format(v) + for k, v in sorted(tag_map.items()) + ) + else: + return '' - def _get_expr_map(self): + @classmethod + def validate(cls, expr_val_list): expr_map = {} - def callback(expr_val, param): - expr_map[expr_val.expr] = expr_val + def update_map(expr_val1): + # The check does not apply for non-reusable operators, since it is + # expected that the same expression may reference multiple values + # of the same Expression. 
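The duplicate detection in ``validate()`` hinges on ``dict.setdefault()`` returning the value already stored for a key, so that an identity comparison reveals a clash. A standalone illustration::

    expr_map = {}
    first, second = object(), object()
    print(expr_map.setdefault('expr', first) is first)    # True: value stored
    print(expr_map.setdefault('expr', second) is second)  # False: clash detected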
+                if not expr_val1.expr.op.reusable:
+                    return
 
-        # Consume the generator
-        for _ in self.get_parent_expr_vals(callback):
-            pass
+                expr_val2 = expr_map.setdefault(expr_val1.expr, expr_val1)
+                # Check that there is only one ExprVal per Expression, for all
+                # expressions that were (indirectly) involved in the
+                # computation of expr_val_list
+                if expr_val2 is not expr_val1:
+                    raise ValueError
+
+        try:
+            for expr_val in expr_val_list:
+                # DFS traversal
+                expr_val.get_by_predicate(update_map)
+        except ValueError:
+            return False
+        else:
+            return True
 
-        return expr_map
+    def get_id(self, *args, with_tags=True, **kwargs):
+        return self.expr.get_id(
+            with_tags=with_tags,
+            expr_val=self,
+            *args, **kwargs
+        )
+
+class UnEvaluatedExprVal(ExprVal):
+    def __init__(self, expr):
+        super().__init__(
+            expr=expr,
+            param_map=ExprValParamMap(),
+            uuid=None,
+            value=NoValue,
+            excep=NoValue,
+        )
 
 class Consumer:
     def __init__(self):
@@ -1913,5 +2264,6 @@ class Consumer:
 
 class ExprData(dict):
     def __init__(self):
-        pass
+        super().__init__()
+        self.uuid = utils.create_uuid()
 
diff --git a/tools/exekall/exekall/main.py b/tools/exekall/exekall/main.py
index 5f72eb9d1272e2872806ab4a7c6d819ab6bc4622..6d98fcf2231e84541c231135cc2d2e9cdf0a15b9 100755
--- a/tools/exekall/exekall/main.py
+++ b/tools/exekall/exekall/main.py
@@ -19,27 +19,186 @@ import argparse
 import collections
 import contextlib
+import copy
 import datetime
-import importlib
+import hashlib
 import inspect
 import io
 import itertools
 import os
 import pathlib
+import shutil
 import sys
 
 from exekall.customization import AdaptorBase
-import exekall.engine as engine
-from exekall.engine import NoValue
 import exekall.utils as utils
-from exekall.utils import take_first, error, warn, debug, info, out
+from exekall.utils import NoValue, error, warn, debug, info, out
+import exekall.engine as engine
+
+DB_FILENAME = 'VALUE_DB.pickle.xz'
+
+# Create an operator for all callables that have been detected in a given
+# set of modules
+def build_op_set(callable_pool, non_reusable_type_set, allowed_pattern_set, adaptor):
+    op_set = {
+        engine.Operator(
+            callable_,
+            non_reusable_type_set=non_reusable_type_set,
+            tags_getter=adaptor.get_tags
+        )
+        for callable_ in callable_pool
+    }
+
+    filtered_op_set = adaptor.filter_op_set(op_set)
+    # Make sure we have all the explicitly allowed operators
+    filtered_op_set.update(
+        op for op in op_set
+        if utils.match_name(op.get_name(full_qual=True), allowed_pattern_set)
+    )
+    return filtered_op_set
+
+def build_patch_map(sweep_spec_list, op_set):
+    patch_map = dict()
+    for sweep_spec in sweep_spec_list:
+        number_type = float
+        callable_pattern, param, start, stop, step = sweep_spec
+        for op in op_set:
+            callable_ = op.callable_
+            callable_name = utils.get_name(callable_, full_qual=True)
+            if not utils.match_name(callable_name, [callable_pattern]):
+                continue
+            patch_map.setdefault(op, dict())[param] = [
+                i for i in utils.sweep_number(
+                    callable_, param,
+                    number_type(start), number_type(stop), number_type(step)
+                )
+            ]
+    return patch_map
+
+def apply_patch_map(patch_map, adaptor):
+    prebuilt_op_set = set()
+    for op, param_patch_map in patch_map.items():
+        try:
+            new_op_set = op.force_param(
+                param_patch_map,
+                tags_getter=adaptor.get_tags
+            )
+            prebuilt_op_set.update(new_op_set)
+        except KeyError as e:
+            error('Callable "{callable_}" has no parameter "{param}"'.format(
+                callable_=op.name,
+                param=e.args[0]
+            ))
+            continue
+
+    return prebuilt_op_set
+
+def load_from_db(db, adaptor, non_reusable_type_set, pattern_list, uuid_list,
+        uuid_args):
+    # We do not filter on UUID if we only got a type pattern list
+    load_all_uuid = (
+        pattern_list and not (
+            uuid_list
+            or uuid_args
+        )
+    )
+
+    froz_val_set_set = set()
+    if load_all_uuid:
+        froz_val_set_set.update(
+            utils.get_froz_val_set_set(db, None, pattern_list)
+        )
+    elif uuid_list:
+        froz_val_set_set.update(
+            utils.get_froz_val_set_set(db, uuid_list,
+                pattern_list
+            ))
+    elif uuid_args:
+        # Get the froz_val value we are interested in
+        froz_val_list = utils.flatten_seq(
+            utils.get_froz_val_set_set(db, [uuid_args],
+                pattern_list
+            ))
+        for froz_val in froz_val_list:
+            # Reload the whole context, except froz_val itself since we
+            # only want its arguments. We load the "indirect" arguments as
+            # well to ensure references to their types will be fulfilled by
+            # them instead of computing new values.
+            froz_val_set_set.add(frozenset(froz_val.get_by_predicate(
+                lambda v: v is not froz_val and v.value is not NoValue
+            )))
+
+    # Otherwise, reload all the root froz_val values
+    else:
+        froz_val_set_set.update(
+            frozenset(froz_val_seq)
+            for froz_val_seq in db.froz_val_seq_list
+        )
+
+    prebuilt_op_set = set()
+
+    # Build the set of PrebuiltOperator that will inject the loaded values
+    # into the tests
+    for froz_val_set in froz_val_set_set:
+        froz_val_list = [
+            froz_val for froz_val in froz_val_set
+            if froz_val.value not in (NoValue, None)
+        ]
+        if not froz_val_list:
+            continue
+
+        def key(froz_val):
+            # Since no two sub-expressions are allowed to compute values of a
+            # given type, it is safe to assume that grouping by the
+            # non-tagged ID will group together all values of compatible
+            # types into one PrebuiltOperator per root Expression.
+            return froz_val.get_id(full_qual=True, with_tags=False)
+
+        for full_id, group in itertools.groupby(froz_val_list, key=key):
+            froz_val_list = list(group)
+
+            type_ = utils.get_common_base(
+                type(froz_val.value)
+                for froz_val in froz_val_list
+            )
+
+            # Do not reload non-reusable objects, since that would lead to an
+            # unexpected mix-up when multiple of them were used in the same
+            # expression.
+            # Also, it would break the guarantee that they won't be used twice.
+            if type_ in non_reusable_type_set:
+                continue
+
+            id_ = froz_val_list[0].get_id(
+                full_qual=False,
+                qual=False,
+                # Do not include the tags to avoid having them displayed
+                # twice, and to avoid wrongfully using the tag of the first
+                # item in the list for all items.
+                with_tags=False,
+            )
+
+            prebuilt_op_set.add(
+                engine.PrebuiltOperator(
+                    type_, froz_val_list, id_=id_,
+                    non_reusable_type_set=non_reusable_type_set,
+                    tags_getter=adaptor.get_tags,
+                ))
+
+    return prebuilt_op_set
 
 def _main(argv):
     parser = argparse.ArgumentParser(description="""
-    LISA test runner
+Test runner
+
+PATTERNS
+    All patterns are fnmatch patterns, following basic shell globbing syntax.
+    A pattern starting with "!" is used as a negative pattern.
     """,
     formatter_class=argparse.RawTextHelpFormatter)
 
+    parser.add_argument('--debug', action='store_true',
+        help="""Show complete Python backtrace when exekall crashes.""")
+
     subparsers = parser.add_subparsers(title='subcommands', dest='subcommand')
 
     run_parser = subparsers.add_parser('run',
@@ -48,95 +207,141 @@ def _main(argv):
     """,
     formatter_class=argparse.RawTextHelpFormatter)
 
+    # It is not possible to give a default value to positional options,
+    # otherwise adaptor-specific options' values will be picked up as Python
+    # sources, and importing the modules will therefore fail with an unknown
+    # file error.
    run_parser.add_argument('python_files', nargs='+',
        metavar='PYTHON_SRC',
-        help="Python modules files")
-
-    run_parser.add_argument('--adaptor',
-        help="""Adaptor to use from the customization module, if there is more
-than one to choose from.""")
+        help="""Python module files. If a folder is passed, all the files it recursively contains are selected. By default, the current directory is selected.""")
 
-    run_parser.add_argument('--filter',
-        help="""Only run the testcases with an ID matching the filter.""")
-    run_parser.add_argument('--restrict', action='append',
+    run_parser.add_argument('-s', '--select', action='append',
+        metavar='SELECT_PATTERN',
         default=[],
-        help="""Callable names patterns. Types produced by these callables will
-only be produced by these (other callables will be excluded).""")
+        help="""Only run the expressions with an ID matching any of the supplied filters.""")
 
-    run_parser.add_argument('--forbid', action='append',
+    # Same as --select, but allows multiple patterns without needing to
+    # repeat the option. This is mostly available to support wrapper
+    # scripts, and is not recommended for direct use since it can lead to
+    # some parsing ambiguities.
+    run_parser.add_argument('--select-multiple', nargs='*',
         default=[],
-        help="""Type names patterns. Callable returning these types or any subclass will not be called.""")
+        help=argparse.SUPPRESS,
+    )
 
-    run_parser.add_argument('--allow', action='append',
-        default=[],
-        help="""Allow using callable with a fully qualified name matching these patterns, even if they have been not selected for various reasons..""")
+    run_parser.add_argument('--dry-run', action='store_true',
+        help="""Only show the expressions that will be run without running them.""")
 
-    run_parser.add_argument('--modules-root', action='append', default=[],
-        help="Equivalent to setting PYTHONPATH")
+    # Show the list of expressions in reStructuredText format, suitable for
+    # inclusion in Sphinx documentation
+    run_parser.add_argument('--rst-list', action='store_true',
+        help=argparse.SUPPRESS)
+
+    run_parser.add_argument('--log-level', default='info',
+        choices=('debug', 'info', 'warn', 'error', 'critical'),
+        help="""Change the default log level of the standard logging module.""")
+
+    run_parser.add_argument('--verbose', '-v', action='count', default=0,
+        help="""More verbose output. Can be repeated for even more verbosity. This only impacts exekall output; use --log-level for more global settings.""")
 
     artifact_dir_group = run_parser.add_mutually_exclusive_group()
     artifact_dir_group.add_argument('--artifact-root',
        default=os.getenv('EXEKALL_ARTIFACT_ROOT', 'artifacts'),
-        help="Root folder under which the artifact folders will be created")
+        help="Root folder under which the artifact folders will be created. Defaults to EXEKALL_ARTIFACT_ROOT env var.")
 
     artifact_dir_group.add_argument('--artifact-dir',
        default=os.getenv('EXEKALL_ARTIFACT_DIR'),
-        help="""Folder in which the artifacts will be stored. This take
-precedence over --artifact-root""")
+        help="""Folder in which the artifacts will be stored. Defaults to EXEKALL_ARTIFACT_DIR env var.""")
 
-    run_parser.add_argument('--load-db',
-        help="""Reload a database and use its results as prebuilt objects.""")
+    run_parser.add_argument('--load-db', action='append',
+        default=[],
+        help="""Reload a database to use some of its objects.
The DB and its artifact directory will be merged into the produced DB at the end of the execution, to form a self-contained artifact directory.""")
 
-    run_parser.add_argument('--load-type', action='append', default=[],
-        help="""Load the (indirect) instances of the given class from the
-database instead of the root objects.""")
+    run_parser.add_argument('--load-type', action='append',
+        metavar='LOAD_TYPE_PATTERN',
+        default=[],
+        help="""Load the (indirect) instances of the given class from the database instead of the root objects.""")
 
     uuid_group = run_parser.add_mutually_exclusive_group()
 
-    uuid_group.add_argument('--load-uuid', action='append', default=[],
-        help="""Load the given UUID from the database. What is reloaded can be
-refined with --load-type.""")
+    uuid_group.add_argument('--load-uuid', action='append',
+        default=[],
+        help="""Load the given UUID from the database.""")
+
+    uuid_group.add_argument('--replay',
+        help="""Replay the execution of the given UUID, loading as many prerequisites from the DB as possible.""")
 
+    # Load the parameters that were used to compute the value with the given
+    # UUID from the database. This can be used as a more flexible form of
+    # --replay that does not imply restricting the selection.
     uuid_group.add_argument('--load-uuid-args',
-        help="""Load the parameters of the values that were used to compute the
-given UUID from the database.""")
+        help=argparse.SUPPRESS)
+
+    run_parser.add_argument('--restrict', action='append',
+        metavar='RESTRICT_PATTERN',
+        default=[],
+        help="""Callable name patterns. Types produced by these callables will only be produced by these (other callables will be excluded).""")
+
+    run_parser.add_argument('--forbid', action='append',
+        metavar='FORBID_PATTERN',
+        default=[],
+        help="""Fully qualified type name patterns. Callables returning these types or any subclass will not be called.""")
+
+    run_parser.add_argument('--allow', action='append',
+        metavar='ALLOW_PATTERN',
+        default=[],
+        help="""Allow using callables with a fully qualified name matching these patterns, even if they have not been selected for various reasons.""")
 
     goal_group = run_parser.add_mutually_exclusive_group()
     goal_group.add_argument('--goal', action='append',
-        help="""Compute expressions leading to an instance of the specified
-class or a subclass of it.""")
+        metavar='GOAL_PATTERN',
+        default=[],
+        help="""Compute expressions leading to an instance of a class with a name matching this pattern (or a subclass of it).""")
 
     goal_group.add_argument('--callable-goal', action='append',
+        metavar='CALLABLE_GOAL_PATTERN',
         default=[],
-        help="""Compute expressions ending with a callable which name is
-matching this pattern.""")
+        help="""Compute expressions ending with a callable whose name matches this pattern.""")
 
     run_parser.add_argument('--sweep', nargs=5, action='append', default=[],
-        metavar=('CALLABLE', 'PARAM', 'START', 'STOP', 'STEP'),
-        help="""Parametric sweep on a function parameter.
-It needs five fields: the qualified name of the callable (pattern can be used),
-the name of the parameter, the start value, stop value and step size.""")
+        metavar=('CALLABLE_PATTERN', 'PARAM', 'START', 'STOP', 'STEP'),
+        help="""Parametric sweep on a function parameter.
It needs five fields:
+    * pattern matching the qualified name of the callable
+    * name of the parameter
+    * start value
+    * stop value
+    * step size.""")
 
-    run_parser.add_argument('--verbose', '-v', action='count', default=0,
-        help="""More verbose output.""")
+    run_parser.add_argument('--template-scripts', metavar='SCRIPT_FOLDER',
+        help="""Only create the template scripts of the expressions without running them.""")
 
-    run_parser.add_argument('--dry-run', action='store_true',
-        help="""Only show the tests that will be run without running them.""")
+    run_parser.add_argument('--adaptor',
+        help="""Adaptor to use from the customization module, if there is more than one to choose from.""")
 
-    run_parser.add_argument('--template-scripts', metavar='SCRIPT_FOLDER',
-        help="""Only create the template scripts of the tests without running them.""")
 
-    run_parser.add_argument('--log-level', default='info',
-        choices=('debug', 'info', 'warn', 'error', 'critical'),
-        help="""Change the default log level of the standard logging module.""")
+    merge_parser = subparsers.add_parser('merge',
+    description="""
+Merge artifact directories of "exekall run" executions.
 
-    run_parser.add_argument('--debug', action='store_true',
-        help="""Show complete Python backtrace when exekall crashes.""")
+By default, it will use hardlinks instead of copies to improve speed and avoid
+eating up a large amount of space, but that means that artifact directories
+should be treated as read-only.
+    """,
+    formatter_class=argparse.RawTextHelpFormatter)
+
+    merge_parser.add_argument('artifact_dirs', nargs='+',
+        help="""Artifact directories created using "exekall run", or value databases to merge.""")
+
+    merge_parser.add_argument('-o', '--output', required=True,
+        help="""Output merged artifacts directory or value database.""")
+
+    merge_parser.add_argument('--copy', action='store_true',
+        help="""Force copying files, instead of using hardlinks.""")
 
     # Avoid showing help message on the incomplete parser. Instead, we carry on
-    # and the help will be displayed after the parser customization has a
-    # chance to take place.
+    # and the help will be displayed after the parser customization of the run
+    # subcommand has a chance to take place.
     help_options = ('-h', '--help')
     no_help_argv = [
        arg for arg in argv
@@ -151,7 +356,9 @@ the name of the parameter, the start value, stop value and step size.""")
     # --help for example. If it was for another reason, it will fail again and
     # show the message.
 except SystemExit:
-        args, _ = parser.parse_known_args(argv)
+        parser.parse_known_args(argv)
+        # That should never be reached
+        assert False
 
     if not args.subcommand:
         parser.print_help()
@@ -160,8 +367,118 @@ the name of the parameter, the start value, stop value and step size.""")
     global show_traceback
     show_traceback = args.debug
 
+    # Some subcommands do not need parser customization, in which case we
+    # parse the command line more strictly
+    if args.subcommand not in ['run']:
+        parser.parse_args(argv)
+
+    if args.subcommand == 'run':
+        # do_run needs to reparse the CLI, so it needs the parser and argv
+        return do_run(args, parser, run_parser, argv)
+
+    elif args.subcommand == 'merge':
+        return do_merge(
+            artifact_dirs=args.artifact_dirs,
+            output_dir=args.output,
+            use_hardlink=(not args.copy),
+        )
+
+def do_merge(artifact_dirs, output_dir, use_hardlink=True, output_exist=False):
+    output_dir = pathlib.Path(output_dir)
+
+    artifact_dirs = [pathlib.Path(path) for path in artifact_dirs]
+    # Dispatch folders and databases
+    db_path_list = [path for path in artifact_dirs if path.is_file()]
+    artifact_dirs = [path for path in artifact_dirs if path.is_dir()]
+
+    # Only DB paths
+    if not artifact_dirs:
+        merged_db_path = output_dir
+    else:
+        # This will fail loudly if the folder already exists, unless
+        # output_exist is True
+        os.makedirs(str(output_dir), exist_ok=output_exist)
+        merged_db_path = output_dir/DB_FILENAME
+
+    testsession_uuid_list = []
+    for artifact_dir in artifact_dirs:
+        with (artifact_dir/'UUID').open(encoding='utf-8') as f:
+            testsession_uuid = f.read().strip()
+        testsession_uuid_list.append(testsession_uuid)
+
+        link_base_path = pathlib.Path('ORIGIN', testsession_uuid)
+
+        # Copy all the files recursively
+        for dirpath, dirnames, filenames in os.walk(str(artifact_dir)):
+            dirpath = pathlib.Path(dirpath)
+            for name in filenames:
+                path = dirpath/name
+                rel_path = pathlib.Path(os.path.relpath(str(path), str(artifact_dir)))
+                link_path = output_dir/link_base_path/rel_path
+
+                levels = pathlib.Path(*(['..'] * (
+                    len(rel_path.parents)
+                    + len(link_base_path.parents)
+                    - 1
+                )))
+                src_link_path = levels/rel_path
+
+                # Top-level files are relocated under an ORIGIN folder instead
+                # of having a symlink, otherwise they would clash
+                if dirpath == artifact_dir:
+                    dst_path = link_path
+                    create_link = False
+                # Otherwise, UUIDs will ensure that there is no clash
+                else:
+                    dst_path = output_dir/rel_path
+                    create_link = True
+
+                # Create the folder and make sure that all its parents get the
+                # same stats as the original one, in order to preserve creation
+                # date.
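+                # (As a sanity check of the relative-symlink arithmetic above,
+                # with hypothetical paths: link_base_path == 'ORIGIN/<uuid>'
+                # has 2 parents and rel_path == 'sub/dir/f.txt' has 3, so
+                # levels == '../../../..', and the symlink created at
+                # ORIGIN/<uuid>/sub/dir/f.txt resolves back to
+                # output_dir/sub/dir/f.txt.)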
+ os.makedirs(str(dst_path.parent), exist_ok=True) + # We do not do copystat on the topmost parent, as it is shared + # by all original artifact_dir + for parent in list(rel_path.parents)[:-2]: + stat_src = artifact_dir/parent + stat_dst = output_dir/parent + shutil.copystat(str(stat_src), str(stat_dst)) + + # Create a mirror of the original hierarchy + if create_link: + os.makedirs(str(link_path.parent), exist_ok=True) + link_path.symlink_to(src_link_path) + + if use_hardlink: + os.link(str(path), str(dst_path)) + # Preserve the original creation date + shutil.copystat(str(path), str(dst_path), follow_symlinks=False) + else: + shutil.copy2(str(path), str(dst_path)) + + if dirpath == artifact_dir and name == DB_FILENAME: + db_path_list.append(path) + + if artifact_dirs: + # Combine the origin UUIDs to have a stable UUID for the merged + # artifacts + combined_uuid = hashlib.sha256( + b'\n'.join( + uuid_.encode('ascii') + for uuid_ in sorted(testsession_uuid_list) + ) + ).hexdigest()[:32] + with (output_dir/'UUID').open('wt') as f: + f.write(combined_uuid+'\n') + + merged_db = engine.ValueDB.merge( + engine.ValueDB.from_path(path) + for path in db_path_list + ) + merged_db.to_path(merged_db_path) + +def do_run(args, parser, run_parser, argv): # Import all modules, before selecting the adaptor - module_set = {utils.import_file(path) for path in args.python_files} + module_set = utils.import_paths(args.python_files) # Look for a customization submodule in one of the parent packages of the # modules we specified on the command line. @@ -186,26 +503,37 @@ the name of the parameter, the start value, stop value and step size.""") dry_run = args.dry_run only_template_scripts = args.template_scripts - type_goal_pattern = args.goal + rst_expr_list = args.rst_list + if rst_expr_list: + dry_run = True + + type_goal_pattern_set = set(args.goal) callable_goal_pattern_set = set(args.callable_goal) - if not (type_goal_pattern or callable_goal_pattern_set): - type_goal_pattern = set(adaptor_cls.get_default_type_goal_pattern_set()) + if not (type_goal_pattern_set or callable_goal_pattern_set): + type_goal_pattern_set = set(adaptor_cls.get_default_type_goal_pattern_set()) - load_db_path = args.load_db + load_db_path_list = args.load_db load_db_pattern_list = args.load_type load_db_uuid_list = args.load_uuid - load_db_uuid_args = args.load_uuid_args + load_db_replay_uuid = args.replay + load_db_uuid_args = load_db_replay_uuid or args.load_uuid_args + + user_filter_set = set(args.select) + user_filter_set.update(args.select_multiple) + + if load_db_replay_uuid and user_filter_set: + run_parser.error('--replay and --select cannot be used at the same time') + + if load_db_replay_uuid and not load_db_path_list: + run_parser.error('--load-db must be specified to use --replay') - user_filter = args.filter restricted_pattern_set = set(args.restrict) forbidden_pattern_set = set(args.forbid) allowed_pattern_set = set(args.allow) allowed_pattern_set.update(restricted_pattern_set) allowed_pattern_set.update(callable_goal_pattern_set) - sys.path.extend(args.modules_root) - # Setup the artifact_dir so we can create a verbose log in there date = datetime.datetime.now().strftime('%Y%m%d_%H:%M:%S') testsession_uuid = utils.create_uuid() @@ -227,262 +555,70 @@ the name of the parameter, the start value, stop value and step size.""") # Update the CLI arguments so the customization module has access to the # correct value args.artifact_dir = artifact_dir - debug_log = artifact_dir.joinpath('debug_log.txt') - info_log = 
artifact_dir.joinpath('info_log.txt') + debug_log = artifact_dir/'DEBUG.log' + info_log = artifact_dir/'INFO.log' utils.setup_logging(args.log_level, debug_log, info_log, verbose=verbose) - non_reusable_type_set = adaptor.get_non_reusable_type_set() - - # Get the prebuilt operators from the adaptor - if not load_db_path: - prebuilt_op_pool_list = adaptor.get_prebuilt_list() - - # Load objects from an existing database - else: - db = adaptor.load_db(load_db_path) - - # We do not filter on UUID if we only got a type pattern list - load_all_uuid = ( - load_db_pattern_list and not ( - load_db_uuid_list - or load_db_uuid_args - ) - ) - - serial_res_set = set() - if load_all_uuid: - serial_res_set.update( - utils.load_serial_from_db(db, None, load_db_pattern_list) - ) - elif load_db_uuid_list: - serial_res_set.update( - utils.load_serial_from_db(db, load_db_uuid_list, - load_db_pattern_list - )) - elif load_db_uuid_args: - # Get the serial value we are interested in - serial_list = utils.flatten_nested_seq( - utils.load_serial_from_db(db, [load_db_uuid_args], - load_db_pattern_list - )) - for serial in serial_list: - # Get all the UUIDs of its parameters - param_uuid_list = [ - param_serial.value_uuid - for param_serial in serial.param_expr_val_map.values() - ] - - serial_res_set.update( - utils.load_serial_from_db(db, param_uuid_list, - load_db_pattern_list - )) - - # Otherwise, reload all the root serial values - else: - serial_res_set.update( - frozenset(l) - for l in db.obj_store.serial_seq_list - ) - - # Remove duplicates accross sets - loaded_serial = set() - serial_res_set_ = set() - for serial_res in serial_res_set: - serial_res = frozenset(serial_res - loaded_serial) - loaded_serial.update(serial_res) - if serial_res: - serial_res_set_.add(serial_res) - serial_res_set = serial_res_set_ - - # Build the list of PrebuiltOperator that will inject the loaded values - # into the tests - prebuilt_op_pool_list = list() - for serial_res in serial_res_set: - serial_list = [ - serial for serial in serial_res - if serial.value is not NoValue - ] - if not serial_list: - continue + # Get the set of all callables in the given set of modules + callable_pool = utils.get_callable_set(module_set, verbose=verbose) - def key(serial): - # Since no two sub-expression is allowed to compute values of a - # given type, it is safe to assume that grouping by the - # non-tagged ID will group together all values of compatible - # types into one PrebuiltOperator per root Expression. - return serial.get_id(full_qual=True, with_tags=False) + # Build the pool of operators from the callables + non_reusable_type_set = set(utils.flatten_seq( + utils.get_subclasses(cls) + for cls in adaptor.get_non_reusable_type_set() + )) - for full_id, group in itertools.groupby(serial_list, key=key): - serial_list = list(group) + op_set = build_op_set( + callable_pool, non_reusable_type_set, allowed_pattern_set, adaptor, + ) - type_ = type(serial_list[0].value) - id_ = serial_list[0].get_id( - full_qual=False, - qual=False, - # Do not include the tags to avoid having them displayed - # twice, and to avoid wrongfully using the tag of the first - # item in the list for all items. 
- with_tags=False, + # Load objects from an existing database + if load_db_path_list: + db_list = [] + for db_path in load_db_path_list: + db = adaptor.load_db(db_path) + op_set.update( + load_from_db(db, adaptor, non_reusable_type_set, + load_db_pattern_list, load_db_uuid_list, load_db_uuid_args ) - prebuilt_op_pool_list.append( - engine.PrebuiltOperator( - type_, serial_list, id_=id_, - non_reusable_type_set=non_reusable_type_set, - tags_getter=adaptor.get_tags, - )) - - # Pool of all callable considered - callable_pool = utils.get_callable_set(module_set, verbose=verbose) - op_pool = { - engine.Operator( - callable_, - non_reusable_type_set=non_reusable_type_set, - tags_getter=adaptor.get_tags - ) - for callable_ in callable_pool - } - filtered_op_pool = adaptor.filter_op_pool(op_pool) - # Make sure we have all the explicitely allowed operators - filtered_op_pool.update( - op for op in op_pool - if utils.match_name(op.get_name(full_qual=True), allowed_pattern_set) - ) - op_pool = filtered_op_pool + ) + db_list.append(db) + # Get the prebuilt operators from the adaptor + else: + db_list = [] + op_set.update(adaptor.get_prebuilt_set()) # Force some parameter values to be provided with a specific callable - patch_map = dict() - for sweep_spec in args.sweep: - number_type = float - callable_pattern, param, start, stop, step = sweep_spec - for callable_ in callable_pool: - callable_name = utils.get_name(callable_, full_qual=True) - if not utils.match_name(callable_name, [callable_pattern]): - continue - patch_map.setdefault(callable_name, dict())[param] = [ - i for i in utils.sweep_number( - callable_, param, - number_type(start), number_type(stop), number_type(step) - ) - ] + patch_map = build_patch_map(args.sweep, op_set) + op_set.update(apply_patch_map(patch_map, adaptor)) - for op_name, param_patch_map in patch_map.items(): - for op in op_pool: - if op.name == op_name: - try: - new_op_pool = op.force_param( - param_patch_map, - tags_getter=adaptor.get_tags - ) - prebuilt_op_pool_list.extend(new_op_pool) - except KeyError as e: - error('Callable "{callable_}" has no parameter "{param}"'.format( - callable_=op_name, - param=e.args[0] - )) - continue - - # Register stub PrebuiltOperator for the provided prebuilt instances - op_pool.update(prebuilt_op_pool_list) - - # Sort to have stable output - op_pool = sorted(op_pool, key=lambda x: str(x.name)) - - # Pool of classes that can be produced by the ops - produced_pool = set(op.value_type for op in op_pool) - - # Set of all types that can be depended upon. All base class of types that - # are actually produced are also part of this set, since they can be - # dependended upon as well. - cls_set = set() - for produced in produced_pool: - cls_set.update(utils.get_mro(produced)) - cls_set.discard(object) - cls_set.discard(type(None)) - - # Map all types to the subclasses that can be used when the type is - # requested. 
- cls_map = { - # Make sure the list is deduplicated by building a set first - cls: sorted({ - subcls for subcls in produced_pool - if issubclass(subcls, cls) - }, key=lambda cls: cls.__qualname__) - for cls in cls_set + # Some operators are hidden in IDs since they don't add useful information + # (internal classes) + hidden_callable_set = { + op.callable_ + for op in adaptor.get_hidden_op_set(op_set) } - # Make sure that the provided PrebuiltOperator will be the only ones used - # to provide their types - only_prebuilt_cls = set(itertools.chain.from_iterable( - # Augment the list of classes that can only be provided by a prebuilt - # Operator with all the compatible classes - cls_map[op.obj_type] - for op in prebuilt_op_pool_list - )) - - only_prebuilt_cls.discard(type(NoValue)) - - # Map of all produced types to a set of what operator can create them - def build_op_map(op_pool, only_prebuilt_cls, forbidden_pattern_set): - op_map = dict() - for op in op_pool: - param_map, produced = op.get_prototype() - is_prebuilt_op = isinstance(op, engine.PrebuiltOperator) - if ( - (is_prebuilt_op or produced not in only_prebuilt_cls) - and not utils.match_base_cls(produced, forbidden_pattern_set) - ): - op_map.setdefault(produced, set()).add(op) - return op_map - - op_map = build_op_map(op_pool, only_prebuilt_cls, forbidden_pattern_set) + # These get_id() options are used for all user-exposed listing that is supposed to be + # filterable with user_filter_set (like dry_run) + filterable_id_kwargs = dict( + full_qual=False, + qual=False, + with_tags=False, + hidden_callable_set=hidden_callable_set + ) - # Restrict the production of some types to a set of operators. - restricted_op_set = { - # Make sure that we only use what is available - op for op in itertools.chain.from_iterable(op_map.values()) - if utils.match_name(op.get_name(full_qual=True), restricted_pattern_set) - } - def apply_restrict(produced, op_set, restricted_op_set, cls_map): - restricted_op_set = { - op for op in restricted_op_set - if op.value_type is produced + # Restrict the Expressions that will be executed to just the one we + # care about + if db_list and load_db_replay_uuid: + id_kwargs = copy.copy(filterable_id_kwargs) + del id_kwargs['hidden_callable_set'] + # Let the merge logic handle duplicated UUIDs + db = engine.ValueDB.merge(db_list) + user_filter_set = { + db.get_by_uuid(load_db_replay_uuid).get_id(**id_kwargs) } - if restricted_op_set: - # Make sure there is no other compatible type, so the only operators - # that will be used to satisfy that dependency will be one of the - # restricted_op_set item. - cls_map[produced] = [produced] - return restricted_op_set - else: - return op_set - op_map = { - produced: apply_restrict(produced, op_set, restricted_op_set, cls_map) - for produced, op_set in op_map.items() - } - - # Get the callable goals - root_op_set = set() - if callable_goal_pattern_set: - root_op_set.update( - op for op in op_pool - if utils.match_name(op.get_name(full_qual=True), callable_goal_pattern_set) - ) - - # Get the list of root operators by produced type - if type_goal_pattern: - for produced, op_set in op_map.items(): - # All producers of the goal types can be a root operator in the - # expressions we are going to build, i.e. 
the outermost function call
-            if utils.match_base_cls(produced, type_goal_pattern):
-                root_op_set.update(op_set)
-
-    # Sort for stable output
-    root_op_list = sorted(root_op_set, key=lambda op: str(op.name))
-
-    # Some operators are hidden in IDs since they don't add useful information
-    # (internal classes)
-    hidden_callable_set = adaptor.get_hidden_callable_set(op_map)
 
     # Only print once per parameters' tuple
     if verbose:
@@ -507,116 +643,160 @@ the name of the parameter, the start value, stop value and step size.""")
         handle_non_produced = 'ignore'
         handle_cycle = 'ignore'
 
+    # Get the callable goals, either by the callable name or the value type
+    root_op_set = set(
+        op for op in op_set
+        if (
+            utils.match_name(op.get_name(full_qual=True), callable_goal_pattern_set)
+            or
+            # All producers of the goal types can be a root operator in the
+            # expressions we are going to build, i.e. the outermost function call
+            utils.match_base_cls(op.value_type, type_goal_pattern_set)
+            # Only keep the Expression where the outermost (root) operator is
+            # defined in one of the files that were explicitly specified on the
+            # command line.
+        ) and inspect.getmodule(op.callable_) in module_set
+    )
+
+    # Build the class context from the set of Operators that we collected
+    class_ctx = engine.ClassContext.from_op_set(
+        op_set=op_set,
+        forbidden_pattern_set=forbidden_pattern_set,
+        restricted_pattern_set=restricted_pattern_set
+    )
+
     # Build the list of Expression that can be constructed from the set of
     # callables
-    testcase_list = list(engine.ExpressionWrapper.build_expr_list(
-        root_op_list, op_map, cls_map,
-        non_produced_handler = handle_non_produced,
-        cycle_handler = handle_cycle,
-    ))
+    expr_list = class_ctx.build_expr_list(
+        root_op_set,
+        non_produced_handler=handle_non_produced,
+        cycle_handler=handle_cycle,
+    )
 
     # First, sort with the fully qualified ID so we have the strongest stability
     # possible from one run to another
-    testcase_list.sort(key=lambda expr: take_first(expr.get_id(full_qual=True, with_tags=True)))
+    expr_list.sort(key=lambda expr: expr.get_id(full_qual=True, with_tags=True))
 
     # Then sort again according to what will be displayed. Since it is a stable
     # sort, it will keep a stable order for IDs that look the same but actually
     # differ in their hidden part
-    testcase_list.sort(key=lambda expr: take_first(expr.get_id(qual=False, with_tags=True)))
-
-    # Only keep the Expression where the outermost (root) operator is defined
-    # in one of the files that were explicitely specified on the command line.
- testcase_list = [ - testcase - for testcase in testcase_list - if inspect.getmodule(testcase.op.callable_) in module_set - ] + expr_list.sort(key=lambda expr: expr.get_id(qual=False, with_tags=True)) - if user_filter: - testcase_list = [ - testcase for testcase in testcase_list - if utils.match_name(take_first(testcase.get_id( - # These options need to match what --dry-run gives (unless - # verbose is used) - full_qual=False, - qual=False, - hidden_callable_set=hidden_callable_set)), [user_filter]) + if user_filter_set: + expr_list = [ + expr for expr in expr_list + if utils.match_name(expr.get_id(**filterable_id_kwargs), user_filter_set) ] - if not testcase_list: - info('Nothing to do, exiting ...') - return 0 + if not expr_list: + info('Nothing to do, check --help while passing some python sources to get the full help.') + return 1 + + id_kwargs = { + **filterable_id_kwargs, + 'full_qual': bool(verbose), + } + + if rst_expr_list: + id_kwargs['style'] = 'rst' + for expr in expr_list: + out(expr.get_id(**id_kwargs)) + else: + out('The following expressions will be executed:\n') + for expr in expr_list: + out(expr.get_id(**id_kwargs)) - out('The following expressions will be executed:\n') - for testcase in testcase_list: - out(take_first(testcase.get_id( - full_qual=bool(verbose), - qual=bool(verbose), - hidden_callable_set=hidden_callable_set - ))) - if verbose >= 2: - out(testcase.pretty_structure() + '\n') + if verbose >= 2: + out(expr.get_structure() + '\n') if dry_run: return 0 + exec_ret_code = exec_expr_list( + expr_list=expr_list, + adaptor=adaptor, + artifact_dir=artifact_dir, + testsession_uuid=testsession_uuid, + hidden_callable_set=hidden_callable_set, + only_template_scripts=only_template_scripts, + verbose=verbose, + ) + + # If we reloaded a DB, merge it with the current DB so the outcome is a + # self-contained artifact dir + if load_db_path_list: + orig_list = [ + path if path.is_dir() else path.parent + for path in map(pathlib.Path, load_db_path_list) + ] + do_merge(orig_list, artifact_dir, output_exist=True) + + return exec_ret_code + +def exec_expr_list(expr_list, adaptor, artifact_dir, testsession_uuid, + hidden_callable_set, only_template_scripts, verbose): + if not only_template_scripts: - with open(str(artifact_dir.joinpath('UUID')), 'wt') as f: + with (artifact_dir/'UUID').open('wt') as f: f.write(testsession_uuid+'\n') db_loader = adaptor.load_db out('\nArtifacts dir: {}\n'.format(artifact_dir)) - # Apply the common subexpression elimination before trying to create the - # template scripts - executor_map = engine.Expression.get_executor_map(testcase_list) + # Get a list of ComputableExpression in order to execute them + expr_list = engine.ComputableExpression.from_expr_list(expr_list) - for testcase in executor_map.keys(): - testcase_short_id = take_first(testcase.get_id( + for expr in expr_list: + expr_short_id = expr.get_id( hidden_callable_set=hidden_callable_set, with_tags=False, full_qual=False, qual=False, - )) + ) - data = testcase.data - data['id'] = testcase_short_id - data['uuid'] = testcase.uuid + data = expr.data + data['id'] = expr_short_id + data['uuid'] = expr.uuid - testcase_artifact_dir = pathlib.Path( + expr_artifact_dir = pathlib.Path( artifact_dir, - testcase.op.get_name(full_qual=False), - testcase_short_id, - testcase.uuid + expr_short_id, + expr.uuid ) - testcase_artifact_dir.mkdir(parents=True) - testcase_artifact_dir = testcase_artifact_dir.resolve() + expr_artifact_dir.mkdir(parents=True) + expr_artifact_dir = expr_artifact_dir.resolve() 
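+        # Each expression gets its own artifact folder, laid out as
+        # <artifact_dir>/<expr_short_id>/<expr_uuid>/, in which the UUID, ID,
+        # STRUCTURE and TESTCASE_TEMPLATE.py files are created below.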
data['artifact_dir'] = artifact_dir - data['testcase_artifact_dir'] = testcase_artifact_dir + data['expr_artifact_dir'] = expr_artifact_dir adaptor.update_expr_data(data) - with open(str(testcase_artifact_dir.joinpath('UUID')), 'wt') as f: - f.write(testcase.uuid + '\n') + with (expr_artifact_dir/'UUID').open('wt') as f: + f.write(expr.uuid + '\n') - with open(str(testcase_artifact_dir.joinpath('ID')), 'wt') as f: - f.write(testcase_short_id+'\n') + with (expr_artifact_dir/'ID').open('wt') as f: + f.write(expr_short_id+'\n') - with open(str(testcase_artifact_dir.joinpath('STRUCTURE')), 'wt') as f: - f.write(take_first(testcase.get_id( + with (expr_artifact_dir/'STRUCTURE').open('wt') as f: + f.write(expr.get_id( hidden_callable_set=hidden_callable_set, with_tags=False, full_qual=True, - )) + '\n\n') - f.write(testcase.pretty_structure()) + ) + '\n\n') + f.write(expr.get_structure() + '\n') + + is_svg, dot_output = utils.render_graphviz(expr) + graphviz_path = expr_artifact_dir/'STRUCTURE.{}'.format( + 'svg' if is_svg else 'dot' + ) + with graphviz_path.open('wt', encoding='utf-8') as f: + f.write(dot_output) - with open( - str(testcase_artifact_dir.joinpath('testcase_template.py')), + with (expr_artifact_dir/'TESTCASE_TEMPLATE.py').open( 'wt', encoding='utf-8' ) as f: f.write( - testcase.get_script( + expr.get_script( prefix = 'testcase', - db_path = '../../storage.yml.gz', + db_path = os.path.join('..', DB_FILENAME), db_relative_to = '__file__', db_loader=db_loader )[1]+'\n', @@ -625,28 +805,30 @@ the name of the parameter, the start value, stop value and step size.""") if only_template_scripts: return 0 - result_map = collections.defaultdict(list) - for testcase, executor in executor_map.items(): + # Preserve the execution order, so the summary is displayed in the same + # order + result_map = collections.OrderedDict() + for expr in expr_list: exec_start_msg = 'Executing: {short_id}\n\nID: {full_id}\nArtifacts: {folder}\nUUID: {uuid_}'.format( - short_id=take_first(testcase.get_id( + short_id=expr.get_id( hidden_callable_set=hidden_callable_set, full_qual=False, qual=False, - )), + ), - full_id=take_first(testcase.get_id( + full_id=expr.get_id( hidden_callable_set=hidden_callable_set if not verbose else None, full_qual=True, - )), - folder=testcase.data['testcase_artifact_dir'], - uuid_=testcase.uuid + ), + folder=expr.data['expr_artifact_dir'], + uuid_=expr.uuid ).replace('\n', '\n# ') delim = '#' * (len(exec_start_msg.splitlines()[0]) + 2) out(delim + '\n# ' + exec_start_msg + '\n' + delim) result_list = list() - result_map[testcase] = result_list + result_map[expr] = result_list def pre_line(): out('-' * 40) @@ -658,40 +840,48 @@ the name of the parameter, the start value, stop value and step size.""") sys.stderr.flush() def get_uuid_str(expr_val): - uuid_val = (expr_val.value_uuid or expr_val.excep_uuid) - if uuid_val: - return ' UUID={}'.format(uuid_val) - else: - return '' + return 'UUID={}'.format(expr_val.uuid) + computed_expr_val_set = set() + reused_expr_val_set = set() def log_expr_val(expr_val, reused): - if expr_val.expr.op.callable_ in hidden_callable_set: - return + # Consider that PrebuiltOperator reuse values instead of actually + # computing them. 
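+        # (Their values exist before the expression is executed, e.g. when
+        # they were loaded with --load-db, so reporting them as computed
+        # would be misleading.)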
+ if isinstance(expr_val.expr.op, engine.PrebuiltOperator): + reused = True if reused: - msg='Reusing already computed {id}{uuid}' + msg = 'Reusing already computed {id} {uuid}' + reused_expr_val_set.add(expr_val) else: - msg='Computed {id}{uuid}' + msg = 'Computed {id} {uuid}' + computed_expr_val_set.add(expr_val) - info(msg.format( - id=expr_val.get_id( - full_qual=False, - with_tags=True, - hidden_callable_set=hidden_callable_set, - ), - uuid = get_uuid_str(expr_val), - )) + op = expr_val.expr.op + if ( + op.callable_ not in hidden_callable_set + and not issubclass(op.value_type, engine.ForcedParamType) + ): + info(msg.format( + id=expr_val.get_id( + full_qual=False, + with_tags=True, + hidden_callable_set=hidden_callable_set, + ), + uuid=get_uuid_str(expr_val), + )) - executor = executor(log_expr_val) + # This returns an iterator + executor = expr.execute(log_expr_val) out('') for result in utils.iterate_cb(executor, pre_line, flush_std_streams): - for failed_val in result.get_failed_expr_vals(): - excep = failed_val.excep + for excep_val in result.get_excep(): + excep = excep_val.excep tb = utils.format_exception(excep) - error('Error ({e_name}): {e}\nID: {id}\n{tb}'.format( - id=failed_val.get_id(), - e_name = utils.get_name(type(excep)), + error('{e_name}: {e}\nID: {id}\n{tb}'.format( + id=excep_val.get_id(), + e_name=utils.get_name(type(excep)), e=excep, tb=tb, ), @@ -714,62 +904,70 @@ the name of the parameter, the start value, stop value and step size.""") out('') - testcase_artifact_dir = testcase.data['testcase_artifact_dir'] + expr_artifact_dir = expr.data['expr_artifact_dir'] # Finalize the computation - adaptor.finalize_expr(testcase) + adaptor.finalize_expr(expr) # Dump the reproducer script - with open( - str(testcase_artifact_dir.joinpath('testcase.py')), - 'wt', encoding='utf-8' - ) as f: + with (expr_artifact_dir/'TESTCASE.py').open('wt', encoding='utf-8') as f: f.write( - testcase.get_script( + expr.get_script( prefix = 'testcase', - db_path = '../../../storage.yml.gz', + db_path = os.path.join('..', '..', DB_FILENAME), db_relative_to = '__file__', db_loader=db_loader )[1]+'\n', ) - with open(str(testcase_artifact_dir.joinpath('VALUES_UUID')), 'wt') as f: - for expr_val in result_list: - if expr_val.value is not NoValue: - f.write(expr_val.value_uuid + '\n') + def format_uuid(expr_val_list): + uuid_list = sorted( + expr_val.uuid + for expr_val in expr_val_list + ) + return '\n'.join(uuid_list) + + def write_uuid(path, *args): + with path.open('wt') as f: + f.write(format_uuid(*args) + '\n') - if expr_val.excep is not NoValue: - f.write(expr_val.excep_uuid + '\n') + write_uuid(expr_artifact_dir/'VALUES_UUID', result_list) + write_uuid(expr_artifact_dir/'REUSED_VALUES_UUID', reused_expr_val_set) + write_uuid(expr_artifact_dir/'COMPUTED_VALUES_UUID', computed_expr_val_set) - obj_store = engine.ObjectStore( - engine.Expression.get_all_serializable_vals( - testcase_list, hidden_callable_set, + db = engine.ValueDB( + engine.FrozenExprValSeq.from_expr_list( + expr_list, hidden_callable_set=hidden_callable_set ) ) - db = engine.StorageDB(obj_store) - db_path = artifact_dir.joinpath('storage.yml.gz') + db_path = artifact_dir/DB_FILENAME db.to_path(db_path) out('#'*80) info('Artifacts dir: {}'.format(artifact_dir)) info('Result summary:') - # Display the results - adaptor.process_results(result_map) + # Display the results summary + summary = adaptor.get_summary(result_map) + out(summary) + with (artifact_dir/'SUMMARY').open('wt', encoding='utf-8') as f: + f.write(summary + 
'\n') # Output the merged script with all subscripts - script_path = artifact_dir.joinpath('all_scripts.py') + script_path = artifact_dir/'ALL_SCRIPTS.py' result_name_map, all_scripts = engine.Expression.get_all_script( - testcase_list, prefix='testcase', + expr_list, prefix='testcase', db_path=db_path.relative_to(artifact_dir), db_relative_to='__file__', - obj_store=obj_store, + db=db, db_loader=db_loader, ) - with open(str(script_path), 'wt', encoding='utf-8') as f: - f.write(all_scripts+'\n') + with script_path.open('wt', encoding='utf-8') as f: + f.write(all_scripts + '\n') + + return 0 SILENT_EXCEPTIONS = (KeyboardInterrupt, BrokenPipeError) GENERIC_ERROR_CODE = 1 diff --git a/tools/exekall/exekall/utils.py b/tools/exekall/exekall/utils.py index f92bfe4774d485a8773b2c7b8f67fda344abe3eb..caedc03405fe9e6d7cd47ea40c73f72f2706fd63 100644 --- a/tools/exekall/exekall/utils.py +++ b/tools/exekall/exekall/utils.py @@ -25,7 +25,6 @@ import exekall.engine as engine # Re-export all _utils here from exekall._utils import * -from exekall.engine import take_first def get_callable_set(module_set, verbose=False): # We keep the search local to the packages these modules are defined in, to @@ -68,7 +67,7 @@ def _get_callable_set(module, verbose): # anyway. if inspect.isabstract(return_type): log_f = info if verbose else debug - log_f('Class {} is ignored since it has non-implemented abstract methods'.format( + log_f('Instances of {} will not be created since it has non-implemented abstract methods'.format( get_name(return_type, full_qual=True) )) else: @@ -79,6 +78,8 @@ def sweep_number( callable_, param, start, stop=None, step=1): + step = step if step > 0 else 1 + annot = engine.Operator(callable_).get_prototype()[0] try: type_ = annot[param] diff --git a/tools/exekall/setup.py b/tools/exekall/setup.py index 08db5ba63e8dde98bab68debeaac763b2d234ca0..57e98826f9b01e5fcf4dba778cc95ca2f24ddb59 100755 --- a/tools/exekall/setup.py +++ b/tools/exekall/setup.py @@ -36,12 +36,6 @@ setup( 'console_scripts': ['exekall=exekall.main:main'], }, python_requires='>= 3.5', - install_requires=[ - # Older versions will have troubles with serializing complex nested - # objects hierarchy implementing custom __getstate__ and __setstate__ - "ruamel.yaml >= 0.15.81", - ], - classifiers=[ "Programming Language :: Python :: 3 :: Only", # This is not a standard classifier, as there is nothing defined for
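
Note on the PATTERNS semantics introduced in the top-level parser above: the sketch below approximates how fnmatch patterns with "!" negative patterns can be evaluated. It is illustrative only; the real implementation is exekall.utils.match_name(), whose exact behavior is an assumption here, and match_patterns() is a hypothetical helper.

    import fnmatch

    def match_patterns(name, pattern_set):
        # A name matches if any positive pattern matches it and no
        # "!"-prefixed negative pattern does (assumed semantics).
        positive = [p for p in pattern_set if not p.startswith('!')]
        negative = [p[1:] for p in pattern_set if p.startswith('!')]
        return (
            any(fnmatch.fnmatch(name, p) for p in positive)
            and not any(fnmatch.fnmatch(name, p) for p in negative)
        )

    # Example:
    # match_patterns('lisa.tests.kernel.scheduler.FooTest',
    #                {'*FooTest', '!*staging*'}) -> True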