diff --git a/lisa/conf.py b/lisa/conf.py index 81dca99bb0edb51d5e53488ed92cf577548d61a8..0084b6e7d46f7b6d6bbdb55e4135874f026cfd2c 100644 --- a/lisa/conf.py +++ b/lisa/conf.py @@ -21,8 +21,9 @@ from collections.abc import Mapping, Sequence from collections import OrderedDict import difflib import inspect +import itertools -from lisa.utils import Serializable, Loggable +from lisa.utils import Serializable, Loggable, get_nested_key, set_nested_key class DeferredValue: """ @@ -192,12 +193,97 @@ class KeyDesc(KeyDescBase): prefix=prefix, key=self.name, classinfo=' or '.join( - self._get_cls_name(key_cls, style='rst') + self._get_cls_name(key_cls, style=style) for key_cls in self.classinfo ), help=': ' + self.help if self.help else '' ) +class MissingBaseKeyError(KeyError): + """ + Exception raised when a base key needed to compute a derived key is missing. + """ + pass + +class DerivedKeyDesc(KeyDesc): + """ + Key descriptor describing a key derived from other keys + + Derived keys cannot be added from a source, since they are purely computed + out of other keys. It is also not possible to change their source + priorities. To achieve that, set the source priorities on the keys it is + based on. + + :param base_key_paths: List of paths to the keys this key is derived from. + The paths in the form of a list of string are relative to the current + level, and cannot reference any level above the current one. + :type base_key_paths: list(list(str)) + + :param compute: Function used to compute the value of the key. It takes a + dictionary of base keys specified in ``base_key_paths`` as only + parameter and is expected to return the key's value. + :type compute: collections.abc.Callable + """ + + def __init__(self, name, help, classinfo, base_key_paths, compute): + super().__init__(name=name, help=help, classinfo=classinfo) + self._base_key_paths = base_key_paths + self._compute = compute + + @property + def help(self): + return '(derived from {}) '.format( + ', '.join(sorted( + self._get_base_key_qualname(path) + for path in self._base_key_paths + )) + ) + self._help + + @help.setter + def help(self, val): + self._help = val + + @staticmethod + def _get_base_key_val(conf, path): + return get_nested_key(conf, path) + + @staticmethod + def _get_base_key_src(conf, path): + conf = get_nested_key(conf, path[:-1]) + return conf.resolve_src(path[-1]) + + def _get_base_key_qualname(self, key_path): + return self.parent.qualname + '/' + '/'.join(key_path) + + def _get_base_conf(self, conf): + try: + base_conf = {} + for key_path in self._base_key_paths: + val = self._get_base_key_val(conf, key_path) + set_nested_key(base_conf, key_path, val) + return base_conf + except KeyError as e: + raise MissingBaseKeyError('Missing value for base key "{base_key}" in order to compute derived key "{derived_key}": {msg}'.format( + derived_key=self.qualname, + base_key=e.args[1], + msg=e.args[0], + )) from e + + def compute_val(self, conf): + base_conf = self._get_base_conf(conf) + val = self._compute(base_conf) + self.validate_val(val) + return val + + def get_src(self, conf): + return ','.join( + '{src}({key})'.format( + src=self._get_base_key_src(conf, path), + key=self._get_base_key_qualname(path), + ) + for path in self._base_key_paths + ) + class LevelKeyDesc(KeyDescBase, Mapping): """ Key descriptor defining a hierarchical level in the configuration. @@ -278,7 +364,8 @@ class LevelKeyDesc(KeyDescBase, Mapping): # of nested list avoids getting extra blank line between list items. # That prevents ResStructuredText from thinking each item must be a # paragraph. - suffix = '\n\n..\n\n' + idt if style == 'rst' else '\n' + suffix = '\n\n..\n\n' if style == 'rst' else '\n' + suffix += idt help_ = '{prefix} {key}:{help}{suffix}'.format( prefix=prefix, suffix=suffix, @@ -425,7 +512,6 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): def __init__(self, conf=None, src='conf'): self._nested_init( - parent=None, structure=self.STRUCTURE, src_prio=[] ) @@ -440,7 +526,7 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): def get_help(cls, *args, **kwargs): return cls.STRUCTURE.get_help(*args, **kwargs) - def _nested_init(self, parent, structure, src_prio): + def _nested_init(self, structure, src_prio): """Called to initialize nested instances of the class for nested configuration levels.""" self._structure = structure @@ -459,7 +545,6 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): for key, key_desc in self._structure.items(): if isinstance(key_desc, LevelKeyDesc): self._sublevel_map[key] = self._nested_new( - parent = self, structure = key_desc, src_prio = self._src_prio, ) @@ -562,6 +647,12 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): # sublevels have already been initialized when the root object # was created. self._sublevel_map[key].add_src(src, val, filter_none=filter_none, fallback=fallback) + # Derived keys cannot be set, since they are purely derived from + # other keys + elif isinstance(key_desc, DerivedKeyDesc): + raise ValueError('Cannot set a value for a derived key "{key}"'.format( + key=key_desc.qualname, + ), key_desc.qualname) # Otherwise that is a leaf value that we store at that level else: self._key_map.setdefault(key, {})[src] = val @@ -620,18 +711,21 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): """ key_desc = self._structure[key] + qual_key = key_desc.qualname if isinstance(key_desc, LevelKeyDesc): - key = key_desc.qualname - raise ValueError('Cannot force source of the sub-level "{key}" in {cls}'.format( - key=key, - cls=type(self).__qualname__ - ), key) - - # None means removing the src override for that key - if src_prio is None: - self._src_override.pop(key, None) + raise ValueError('Cannot force source of the sub-level "{key}"'.format( + key=qual_key, + ), qual_key) + elif isinstance(key_desc, DerivedKeyDesc): + raise ValueError('Cannot force source of a derived key "{key}"'.format( + key=qual_key, + ), qual_key) else: - self._src_override[key] = src_prio + # None means removing the src override for that key + if src_prio is None: + self._src_override.pop(key, None) + else: + self._src_override[key] = src_prio def _get_nested_src_override(self): # Make a copy to avoid modifying it @@ -666,19 +760,23 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): return mapping def _resolve_prio(self, key): - if key not in self._key_map: - return [] + key_desc = self._structure[key] - # Get the priority list from the prio override list, or just the - # default prio list - src_list = self._src_override.get(key, self._src_prio) + if isinstance(key_desc, DerivedKeyDesc): + return [key_desc.get_src(self)] + elif key not in self._key_map: + return [] + else: + # Get the priority list from the prio override list, or just the + # default prio list + src_list = self._src_override.get(key, self._src_prio) - # Only include a source if it holds an actual value for that key - src_list = [ - src for src in src_list - if src in self._key_map[key] - ] - return src_list + # Only include a source if it holds an actual value for that key + src_list = [ + src for src in src_list + if src in self._key_map[key] + ] + return src_list def resolve_src(self, key): """ @@ -784,22 +882,33 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): if isinstance(key_desc, LevelKeyDesc): return self._sublevel_map[key] - - # Compute the source to use for that key - if src is None: + elif isinstance(key_desc, DerivedKeyDesc): + # Specifying a source is an error for a derived key + if src is not None: + key = key_desc.qualname + raise ValueError('Cannot specify the source when getting "{key}" since it is a derived key'.format( + key=key, + src=src, + ), key) + + val = key_desc.compute_val(self) src = self.resolve_src(key) + else: + # Compute the source to use for that key + if src is None: + src = self.resolve_src(key) - try: - val = self._key_map[key][src] - except KeyError: - key = key_desc.qualname - raise KeyError('Key "{key}" is not available from source "{src}"'.format( - key=key, - src=src, - ), key) + try: + val = self._key_map[key][src] + except KeyError: + key = key_desc.qualname + raise KeyError('Key "{key}" is not available from source "{src}"'.format( + key=key, + src=src, + ), key) - if eval_deferred: - val = self._eval_deferred_val(src, key) + if eval_deferred: + val = self._eval_deferred_val(src, key) try: frame_conf = inspect.stack()[2] @@ -828,9 +937,8 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): key_desc = self._structure[key] if isinstance(key_desc, LevelKeyDesc): key = key_desc.qualname - raise ValueError('Key "{key}" is a nested configuration level in {cls}, it does not have a source on its own.'.format( + raise ValueError('Key "{key}" is a nested configuration level, it does not have a source on its own.'.format( key=key, - cls=type(self).__qualname__, ), key) return OrderedDict( @@ -848,7 +956,20 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): """ out = [] idt_style = ' ' - for k, v in self.items(eval_deferred=eval_deferred): + + # We add the derived keys when pretty-printing, for the sake of + # completeness. This will not honor eval_deferred for base keys. + def derived_items(): + for key in self._get_derived_key_names(): + try: + yield key, self[key] + except MissingBaseKeyError: + continue + + for k, v in itertools.chain( + self.items(eval_deferred=eval_deferred), + derived_items() + ): v_cls = type(v) is_sublevel = k in self._sublevel_map if is_sublevel: @@ -888,7 +1009,14 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): return self.get_key(key) def _get_key_names(self): - return list(self._key_map.keys()) + list(self._sublevel_map.keys()) + return sorted(list(self._key_map.keys()) + list(self._sublevel_map.keys())) + + def _get_derived_key_names(self): + return sorted( + key + for key, key_desc in self._structure.items() + if isinstance(key_desc, DerivedKeyDesc) + ) def __iter__(self): return iter(self._get_key_names()) @@ -910,7 +1038,11 @@ class MultiSrcConf(MultiSrcConfABC, Loggable, Mapping): def _ipython_key_completions_(self): "Allow Jupyter keys completion in interactive notebooks" - return self.keys() + regular_keys = set(self.keys()) + # For autocompletion purposes, we show the derived keys + derived_keys = set(self._get_derived_key_names()) + return sorted(regular_keys + derived_keys) + class GenericContainerMetaBase(type): """ diff --git a/lisa/tests/lisa/test_conf.py b/lisa/tests/lisa/test_conf.py index 1096e9adfaeaf16982c5c0a22a5df318ffc68a4a..6b32d1cd2fa728bd16246fd2c3456e3a848fe6a4 100644 --- a/lisa/tests/lisa/test_conf.py +++ b/lisa/tests/lisa/test_conf.py @@ -20,7 +20,7 @@ import os import copy from unittest import TestCase -from lisa.conf import MultiSrcConf, KeyDesc, LevelKeyDesc, TopLevelKeyDesc, IntList +from lisa.conf import MultiSrcConf, KeyDesc, LevelKeyDesc, TopLevelKeyDesc, IntList, DerivedKeyDesc from lisa.tests.lisa.utils import StorageTestCase, HOST_PLAT_INFO, HOST_TARGET_CONF """ A test suite for the MultiSrcConf subclasses.""" @@ -68,10 +68,18 @@ class TestTargetConf(StorageTestCase, TestMultiSrcConfBase): # Make copies to avoid mutating the original one self.conf = copy.copy(HOST_TARGET_CONF) +def compute_derived(base_conf): + return base_conf['foo'] + sum(base_conf['bar']) + base_conf['sublevel']['subkey'] + INTERNAL_STRUCTURE = ( KeyDesc('foo', 'foo help', [int]), KeyDesc('bar', 'bar help', [IntList]), KeyDesc('multitypes', 'multitypes help', [IntList, str, None]), + LevelKeyDesc('sublevel', 'sublevel help', ( + KeyDesc('subkey', 'subkey help', [int]), + )), + DerivedKeyDesc('derived', 'derived help', [int], + [['foo'], ['bar'], ['sublevel', 'subkey']], compute_derived), ) class TestConf(MultiSrcConf): @@ -84,7 +92,9 @@ class TestConfWithDefault(MultiSrcConf): INTERNAL_STRUCTURE ) - DEFAULT_SRC = {'bar': [0, 1, 2]} + DEFAULT_SRC = { + 'bar': [0, 1, 2], + } class TestMultiSrcConf(TestMultiSrcConfBase): def test_add_src_one_key(self): @@ -118,6 +128,20 @@ class TestTestConf(StorageTestCase, TestMultiSrcConf): with self.assertRaises(KeyError): self.conf['foo'] + def test_derived(self): + conf = copy.deepcopy(self.conf) + conf.add_src('mysrc', {'foo': 1}) + # Two missing base keys + with self.assertRaises(KeyError): + conf['derived'] + conf.add_src('mysrc2', { + 'bar': [1, 2], + 'sublevel': { + 'subkey': 42 + } + }) + self.assertEqual(conf['derived'], 46) + def test_force_src_nested(self): conf = copy.deepcopy(self.conf) conf.add_src('mysrc', {'bar': [6,7]}) @@ -153,7 +177,11 @@ class TestTestConfWithDefault(StorageTestCase, TestMultiSrcConf): self.conf = TestConfWithDefault() def test_default_src(self): - self.assertEqual(dict(self.conf), dict(TestConfWithDefault.DEFAULT_SRC)) + ref = dict(TestConfWithDefault.DEFAULT_SRC) + # A freshly built object still has all the level keys, even if it has + # no leaves + ref['sublevel'] = {} + self.assertEqual(dict(self.conf), ref) def test_add_src_one_key_fallback(self): conf = copy.deepcopy(self.conf) diff --git a/lisa/utils.py b/lisa/utils.py index 8e41ae785cb384a1f11ef08a676c40a5c6c49c06..b9f061e0b77d2f12b4c07035b573636362ee8c5e 100644 --- a/lisa/utils.py +++ b/lisa/utils.py @@ -504,4 +504,53 @@ def deduplicate(seq, keep_last=True, key=lambda x: x): ) return list(reorder(dedup.values())) +def get_nested_key(mapping, key_path): + """ + Get a key in a nested mapping + + :param mapping: The mapping to lookup in + :type mapping: collections.abc.Mapping + + :param key_path: Path to the key in the mapping, in the form of a list of + keys. + :type key_path: list + """ + if not key_path: + return mapping + for key in key_path[:-1]: + mapping = mapping[key] + return mapping[key_path[-1]] + +def set_nested_key(mapping, key_path, val, level=None): + """ + Set a key in a nested mapping + + :param mapping: The mapping to update + :type mapping: collections.abc.MutableMapping + + :param key_path: Path to the key in the mapping, in the form of a list of + keys. + :type key_path: list + + :param level: Factory used when creating a level is needed. By default, + ``type(mapping)`` will be called without any parameter. + :type level: collections.abc.Callable + """ + assert key_path + + if level is None: + # This should work for dict and most basic structures + level = type(mapping) + + for key in key_path[:-1]: + try: + mapping = mapping[key] + except KeyError: + new_level = level() + mapping[key] = new_level + mapping = new_level + + mapping[key_path[-1]] = val + + # vim :set tabstop=4 shiftwidth=4 textwidth=80 expandtab