diff --git a/lisa/utils.py b/lisa/utils.py index 22969d3778f01c18ee6f85439b59125b2a48bfb4..6076d9f2679a09e92a8c4e7262cc3bb3ce750d62 100644 --- a/lisa/utils.py +++ b/lisa/utils.py @@ -69,6 +69,7 @@ import builtins import typing import ruamel.yaml +import ruamel.yaml.nodes from ruamel.yaml import YAML # These modules may not be installed as they are only used for notebook usage @@ -1025,7 +1026,32 @@ class Serializable( @classmethod def _get_yaml(cls, typ): - yaml = YAML(typ=typ) + # Make a fresh class, in case there is "class global" behaviour that + # really should be instance-related. + class _YAML(YAML): + @property + @memoized + def constructor(self): + # This will rightfully raise in case constructor() is called + # from YAML.__init__(), so that we do not accidentally memoized + # an instance of the wrong type. + return _Constructor() + + # If the user requested an unsafe instance, we provide a safe instance + # with a re-implementation of some unsafe bits. This is because + # ruamel.yaml deprecated typ='unsafe': + # https://yaml.readthedocs.io/en/latest/#ruamelyaml + yaml = _YAML(typ='safe' if typ == 'unsafe' else typ) + + # Ensure we get a fresh constructor class, since add_constructor() and + # add_multi_constructor() are apparently class-global. + ctor = yaml.Constructor + if ctor is None: + _Constructor = ctor + else: + class _Constructor(ctor): + pass + yaml.Constructor = _Constructor # If allow_unicode=True, true unicode characters will be written to the # file instead of being replaced by escape sequence. @@ -1033,25 +1059,41 @@ class Serializable( yaml.default_flow_style = False yaml.indent = 4 - # Replace unknown tags by a placeholder object containing the data. - # This happens when the class was not imported at the time the object - # was deserialized - yaml.constructor.add_constructor(None, cls._yaml_unknown_tag_constructor) - yaml.constructor.add_constructor('!untrusted', cls._yaml_untrusted_constructor) - - if typ == 'unsafe': - yaml.constructor.add_constructor('!include', functools.partial(cls._yaml_include_constructor, parser_typ=typ, subparser_typ='unsafe')) - yaml.constructor.add_constructor('!include-untrusted', functools.partial(cls._yaml_include_constructor, parser_typ=typ, subparser_typ='safe')) - yaml.constructor.add_constructor('!var', cls._yaml_var_constructor) - yaml.constructor.add_multi_constructor('!env:', cls._yaml_env_var_constructor) - yaml.constructor.add_multi_constructor('!call:', cls._yaml_call_constructor) - - return yaml + # typ='full' does not allow loading, and will raise when trying to add any constructor + if typ == 'full': + return yaml + else: + # Replace unknown tags by a placeholder object containing the data. + # This happens when the class was not imported at the time the object + # was deserialized + yaml.constructor.add_constructor(None, cls._yaml_unknown_tag_constructor) + yaml.constructor.add_constructor('!untrusted', cls._yaml_untrusted_constructor) + + if typ == 'unsafe': + yaml.constructor.add_constructor('!include', functools.partial(cls._yaml_include_constructor, parser_typ=typ, subparser_typ='unsafe')) + yaml.constructor.add_constructor('!include-untrusted', functools.partial(cls._yaml_include_constructor, parser_typ=typ, subparser_typ='safe')) + yaml.constructor.add_constructor('!var', cls._yaml_var_constructor) + yaml.constructor.add_multi_constructor('!env:', cls._yaml_env_var_constructor) + yaml.constructor.add_multi_constructor('!call:', cls._yaml_call_constructor) + + # Implement the tags that are in use in ruamel.yaml + # representer.py source. constructor.py seems to be able to + # recognize more tags than that, but they are probably only + # emitted by older versions of the library we do not care + # about. + yaml.constructor.add_multi_constructor('tag:yaml.org,2002:python/object:', functools.partial(cls._yaml_object_constructor, kind='object')) + yaml.constructor.add_multi_constructor('tag:yaml.org,2002:python/object/apply:', functools.partial(cls._yaml_object_constructor, kind='object/apply')) + yaml.constructor.add_multi_constructor('tag:yaml.org,2002:python/object/new:', functools.partial(cls._yaml_object_constructor, kind='object/new')) + yaml.constructor.add_multi_constructor('tag:yaml.org,2002:python/name:', functools.partial(cls._yaml_object_constructor, kind='name')) + yaml.constructor.add_multi_constructor('tag:yaml.org,2002:python/module:', functools.partial(cls._yaml_object_constructor, kind='module')) + yaml.constructor.add_constructor('tag:yaml.org,2002:python/tuple', cls._yaml_tuple_constructor) + yaml.constructor.add_constructor('tag:yaml.org,2002:python/complex', cls._yaml_complex_constructor) + return yaml @classmethod def _yaml_untrusted_constructor(cls, loader, node): if isinstance(node.value, str): - return YAML(typ='safe').load(node.value) + return cls._get_yaml(typ='safe').load(node.value) else: raise TypeError(f'!untrusted node value must be a string. Instead we got a {node.value.__class__.__name__}: {node.value}') @@ -1094,7 +1136,85 @@ class Serializable( if args: _, args = zip(*sorted(args.items(), key=itemgetter(0))) - return loader.make_python_instance(suffix, node, args=args, kwds=kwargs, newobj=False) + f = resolve_dotted_name(suffix) + return f(*args, **kwargs) + + @classmethod + def _yaml_tuple_constructor(cls, loader, node): + return tuple(loader.construct_sequence(node)) + + @classmethod + def _yaml_complex_constructor(cls, loader, node): + return complex(loader.construct_scalar(node)) + + @classmethod + def _yaml_object_constructor(cls, loader, suffix, node, kind): + def setstate(instance, state): + """ + Implement https://docs.python.org/3/library/pickle.html#object.__getstate__ + """ + try: + _setstate = instance.__setstate__ + except AttributeError: + def _setstate(state): + if state is None: + return instance + elif isinstance(state, dict): + instance.__dict__.update(state) + elif isinstance(state, tuple): + assert len(state) == 2 + dct, slots = state + if dct: + instance.__dict__.update(dct) + for k, v in slots.items(): + setattr(instance, k, v) + else: + raise ValueError(f'Non handled state: {state}') + + _setstate(state) + + f = resolve_dotted_name(suffix) + + if kind == 'object': + _cls = f + assert isinstance(_cls, type) + instance = _cls.__new__(_cls) + + loader.recursive_objects[node] = instance + yield instance + + deep = hasattr(instance, '__setstate__') + state = loader.construct_mapping(node, deep=deep) + setstate(instance, state) + elif kind in ('object/apply', 'object/new'): + if kind == 'object/new': + _cls = f + assert isinstance(_cls, type) + f = lambda *args, **kwargs: _cls.__new__(_cls, *args, **kwargs) + + if isinstance(node, ruamel.yaml.nodes.SequenceNode): + args = loader.construct_sequence(node, deep=True) + instance = f(*args) + else: + value = loader.construct_mapping(node, deep=True) + args = value.get('args', []) + kwargs = value.get('kwds', {}) + state = value.get('state', {}) + listitems = value.get('listitems', []) + dictitems = value.get('dictitems', {}) + + instance = f(*args, **kwargs) + setstate(instance, state) + if listitems: + instance.extend(listitems) + if dictitems: + for k, v in dictitems.items(): + instance[k] = v + yield instance + elif kind in ('name', 'module'): + yield f + else: + raise ValueError(f'Unknown reloading kind: {kind}') # Allow !include to use relative paths from the current file. Since we # introduce a global state, we use thread-local storage. @@ -1143,7 +1263,7 @@ class Serializable( else: varname = string - type_ = loader.find_python_name(type_, node.start_mark) + type_ = resolve_dotted_name(type_) assert callable(type_) try: value = os.environ[varname] @@ -1163,7 +1283,7 @@ class Serializable( def _yaml_var_constructor(cls, loader, node): varname = loader.construct_scalar(node) assert isinstance(varname, str) - return loader.find_python_name(varname, node.start_mark) + return resolve_dotted_name(varname) def to_path(self, filepath, fmt=None): """ @@ -1188,10 +1308,10 @@ class Serializable( yaml_kwargs = dict(mode='w', encoding=cls.YAML_ENCODING) if fmt == 'yaml': kwargs = yaml_kwargs - # Dumping in unsafe mode allows creating !!python/object tags, but + # Dumping in full mode allows creating !!python/object tags, but # since it will not load anything that is not already available in # memory there is no security issue. - dumper = cls._get_yaml('unsafe').dump + dumper = cls._get_yaml('full').dump elif fmt == 'yaml-roundtrip': kwargs = yaml_kwargs dumper = cls._get_yaml('rt').dump @@ -1211,7 +1331,7 @@ class Serializable( @classmethod def _to_yaml(cls, data): - yaml = cls._get_yaml('unsafe') + yaml = cls._get_yaml('full') buff = io.StringIO() yaml.dump(data, buff) return buff.getvalue() @@ -1252,9 +1372,6 @@ class Serializable( if fmt == 'yaml': kwargs = dict(mode='r', encoding=cls.YAML_ENCODING) loader = cls._get_yaml('unsafe').load - elif fmt == 'yaml-untrusted': - kwargs = dict(mode='r', encoding=cls.YAML_ENCODING) - loader = cls._get_yaml('safe').load elif fmt == 'pickle': kwargs = dict(mode='rb') loader = pickle.load diff --git a/tests/test_serialize.py b/tests/test_serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..373b9820238f166aa84d2bcc4df8f770cdf54852 --- /dev/null +++ b/tests/test_serialize.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (C) 2024, Arm Limited and contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pathlib import Path +from tempfile import NamedTemporaryFile +from io import StringIO +import json + +from .utils import StorageTestCase + +from lisa.utils import Serializable, UnknownTagPlaceholder + + +class EqDict: + def __eq__(self, other): + assert self.__class__ == other.__class__ + return self.__dict__ == other.__dict__ + + +class MySerializable(EqDict, Serializable): + def __init__(self, x): + self.x = x + + +class MyGetState(MySerializable): + def __getstate__(self): + # We have to wrap in a dict, otherwise we get issues with self.x == + # None (pickle fails and ruamel.yaml incorrectly reloads self.x as an + # empty dict instead). + return {'val': self.x} + + def __setstate__(self, x): + self.x = x['val'] + + +class LoadUser(Serializable): + @classmethod + def from_yaml(cls, path): + return cls._from_path(path, fmt='yaml') + + +class MyClass(EqDict): + pass + + +class MyExcep(Exception): + def __eq__(self, other): + assert self.__class__ is other.__class__ + return self.args == other.args + + +class MyList(list): + pass + + +class MyDict(dict): + pass + + +class TestSerializable(StorageTestCase): + + def _test(self, obj, avoid_fmt=None): + """ + Test that serialization works correctly + """ + def test(obj, fmt): + with NamedTemporaryFile(dir=self.res_dir) as f: + path = Path(f.name) + obj.to_path(path, fmt=fmt) + obj2 = obj.__class__.from_path(path, fmt=fmt) + + assert obj == obj2 + + + fmts = [ + fmt + for fmt in ('yaml', 'pickle') + if fmt not in (avoid_fmt or []) + ] + for wrapper in (MySerializable, MyGetState): + for fmt in fmts: + test(wrapper(obj), fmt) + + def test_int(self): + self._test(42) + + def test_float(self): + self._test(42.42) + + def test_none(self): + self._test(None) + + def test_list(self): + self._test([42, 43]) + + def test_custom_list(self): + # list subclasses are important as they get serialized with + # !!python/object/new and some list items. + self._test(MyList([42, 43])) + + def test_dict(self): + self._test({42: 43}) + + def test_custom_dict(self): + # list subclasses are important as they get serialized with + # !!python/object/new and some dict items. + self._test(MyDict({42: 43})) + + def test_tuple(self): + self._test((42, 43)) + + def test_complex(self): + self._test(complex(42, 43)) + + def test_class(self): + # Test serializing the class, which should serialize by name. + self._test(MyClass) + + def test_exception(self): + # exceptions are important as they get serialized using + # !!python/object/apply + self._test(MyExcep('hello')) + + def test_module(self): + def test(mod): + # pickle does not support modules + return self._test(mod, avoid_fmt={'pickle'}) + + import lisa.utils + test(lisa.utils) + + import lisa + test(lisa) + + def test_include(self): + def load(data, trust): + with NamedTemporaryFile(dir=self.res_dir) as f_trusted: + path_trusted = Path(f_trusted.name) + with NamedTemporaryFile(dir=self.res_dir) as f_untrusted: + path_untrusted = Path(f_untrusted.name) + + path_untrusted.write_text(data) + + include_tag = 'include' if trust else 'include-untrusted' + data_trusted = f'!{include_tag} {json.dumps(str(path_untrusted))}' + path_trusted.write_text(data_trusted) + + return LoadUser.from_yaml(path_trusted) + + # Create an unsafe parser before anything else. This way, if + # ruamel.yaml leaks some state from one parser to the next, we will + # attempt to decode untrusted values using some unsafe parser feature + # and we will get test failures. + assert load('', trust=True) == None + + # smoke test on a builtin value + assert load('42', trust=True) == 42 + assert load('42', trust=False) == 42 + + # Normal include allows untrusted tags + assert load('!!python/object:tests.test_serialize.MyClass {}', trust=True) == MyClass() + + # Check this will not reload as a MyClass() instance, since we are not + # trusting the content of the data. + x = load('!!python/object:tests.test_serialize.MyClass {}', trust=False) + assert isinstance(x, UnknownTagPlaceholder) + assert x.tag == 'tag:yaml.org,2002:python/object:tests.test_serialize.MyClass' + + diff --git a/tests/utils.py b/tests/utils.py index 06ba553d2987b50626363a4e3f57482be0da55b7..2d7586ec9a6d1e61fc09f9fa071ff62d4d35740a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,7 @@ def create_local_target(): """ :returns: A localhost :class:`lisa.target.Target` instance """ - return Target.from_conf(conf=HOST_TARGET_CONF, plat_info=HOST_PLAT_INFO) + return Target.from_conf(conf=HOST_TARGET_CONF, plat_info=HOST_PLAT_INFO, lazy_platinfo=True) class StorageTestCase(TestCase):