diff --git a/lisa/fuzz.py b/lisa/fuzz.py index 4f91b61d8c5e7f670291fa874f185e6f4bc098a9..96a0e98ad7ecf39b5b1137fd18776a831fefee0f 100644 --- a/lisa/fuzz.py +++ b/lisa/fuzz.py @@ -18,6 +18,10 @@ """ Fuzzing API to build random constrained values. +.. note:: The following example shows a direct use of the :class:`Gen` monad, + but be aware that :mod:`lisa.wlgen.rta` API allows mixing both :class:`Gen` + and RTA DSL into the same coroutine function. + **Example**:: import operator @@ -28,9 +32,7 @@ Fuzzing API to build random constrained values. from lisa.fuzz import Gen, Choice, Int, Float, retry_until # The function must be decorated with Gen.lift() so that "await" gains its - # special meaning. In addition to that, parameters are automatically awaited if - # they are an instance of Gen, and the return value is automatically promoted - # to an instance of Gen if it is not already. + # special meaning. @Gen.lift async def make_task(duration=None): # Draw a value from an iterable. @@ -126,43 +128,16 @@ class Gen(StateMonad, Loggable): Random generator monad inspired by Haskell's QuickCheck. """ def __init__(self, f, name=None): - log_level = logging.DEBUG - logger = self.logger - if logger.isEnabledFor(log_level): - caller_info = inspect.stack()[2] - else: - caller_info = None - - @functools.wraps(f) - def wrapper(state): - for i in itertools.count(1): - try: - x = f(state) - except RetryException: - continue - else: - trials = f'after {i} trials ' if i > 1 else '' - if caller_info: - info = f' ({caller_info.filename}:{caller_info.lineno})' - else: - info = '' - val, _ = x - val = str(val) - sep = '\n' + ' ' * 4 - val = sep + val.replace('\n', sep) + '\n' if '\n' in val else val + ' ' - self.logger.log(log_level, f'Drawn {val}{trials}from {self}{info}') - return x - - self.name = name - super().__init__(wrapper) + self.name = name or f.__qualname__ + super().__init__(f) - class _STATE: + class _State: def __init__(self, rng): self.rng = rng @classmethod def make_state(cls, *, rng=None, seed=None): - return cls._STATE( + return cls._State( rng=rng or random.Random(seed), ) @@ -170,6 +145,25 @@ class Gen(StateMonad, Loggable): name = self.name or self._f.__qualname__ return f'{self.__class__.__qualname__}({name})' + @classmethod + def _wrap_coroutine_f(cls, f): + @functools.wraps(f) + async def wrapper(*args, **kwargs): + for i in itertools.count(1): + try: + x = await f(*args, **kwargs) + except RetryException: + continue + else: + trials = f'after {i} trials ' if i > 1 else '' + val = str(x) + sep = '\n' + ' ' * 4 + val = sep + val.replace('\n', sep) + '\n' if '\n' in val else val + ' ' + cls.get_logger().debug(f'Drawn {val}{trials}from {f.__qualname__}') + return x + + return wrapper + class Choices(Gen): """ diff --git a/lisa/monad.py b/lisa/monad.py index 6e66ee11c7fc883ba7020399ce735aebb3ee8c0f..ecf5a33b0750ae220e3ece3a7a54e0a6a34c3dac 100644 --- a/lisa/monad.py +++ b/lisa/monad.py @@ -60,14 +60,72 @@ This allow composing lifted functions easily import abc import functools import inspect +import contextlib + +from lisa.utils import compose, nullcontext + + +class _StateInitializer: + """ + Wrapper for a state-initializing function, along with the underlying + non-lifted coroutine function so that lifted functions can be composed + naturally with await. + """ + def __init__(self, f, coro_f): + self.f = f + self.coro_f = coro_f + functools.update_wrapper(wrapper=self, wrapped=f) + + def __call__(self, *args, **kwargs): + return self.f(*args, **kwargs) + + def __await__(self): + return (yield from self.coro_f().__await__()) + + def state_init_decorator(self, f): + """ + Decorator used to decorate wrapper that are initializing the state + (i.e. calling :class:`_StateInitializer` instances). + + This is necessary in order for resulting values to be awaitable, so + that composition is preserved. + """ + return self.__class__( + f, + self.coro_f, + ) + + +def _consume(coro): + try: + action = coro.send(None) + except StopIteration as e: + return e.value + else: + if isinstance(action, StateMonad): + extra = f'. The top-level function should be decorated with @{action._MONAD_BASE.__qualname__}.lift' + else: + extra = '' + raise TypeError(f'The coroutine could not be consumed as it contains unhandled action: {action}{extra}') + finally: + coro.close() + + +class _RestartableCoro: + def __init__(self, factory): + self._factory = factory + + @property + def coro(self): + return self._factory() class StateMonad(abc.ABC): """ The state monad. - :param f: Callable that takes the state as parameter and returns an - instance of the monad. + :param f: Callable that takes the state as parameter and returns a tuple + ``(value, new_state)``. :type f: collections.abc.Callable """ @@ -85,11 +143,6 @@ class StateMonad(abc.ABC): # value it sees fit using coro.send(). return (yield self) - def __call__(self, *args, **kwargs): - state = self.make_state(*args, **kwargs) - x, _ = self._f(state) - return x - def __init_subclass__(cls, **kwargs): # The one inheriting directly from StateMonad is the base of the # hierarchy @@ -97,6 +150,24 @@ class StateMonad(abc.ABC): cls._MONAD_BASE = cls super().__init_subclass__(**kwargs) + @classmethod + def _process_coroutine_val(cls, val, state): + """ + Subclasses can override this method to customize the return value of + the user-defined lifted coroutine function. + + This allows subclasses to use the current state to override the value + returned by the user. + + :param val: The value actually returned in the user-defined lifted + coroutine function. + :type val: object + + :param state: The current state. + :type state: object + """ + return val + @classmethod def from_f(cls, *args, **kwargs): """ @@ -123,73 +194,158 @@ class StateMonad(abc.ABC): """ return cls.from_f(lambda state: (x, state)) + + @staticmethod + def _loop(_coro, *, state, cls, consume): + async def factory(): + if isinstance(_coro, _RestartableCoro): + coro = _coro.coro + else: + coro = _coro + + _state = state + next_ = lambda: coro.send(None) + while True: + try: + action = next_() + except StopIteration as e: + val = cls._process_coroutine_val(e.value, state) + break + else: + is_cls = isinstance(action, cls) + try: + if is_cls: + val, _state = action._f(_state) + else: + val = await action + except Exception as e: + # We need an intermediate variable here, since + # "e" is not really bound in this scope. + excep = e + next_ = lambda: coro.throw(excep) + else: + next_ = lambda: coro.send(val) + + if isinstance(val, cls): + val, _ = val._f(_state) + + return val + + # Wrap the coroutine in something that can be called to consume it + # entirely + if consume: + return _consume(factory()) + else: + return _RestartableCoro(factory) + @classmethod - def lift(cls, f): + def _wrap_coroutine_f(cls, f): """ - Decorator used to lift a function into the monad, such that it can take - monadic parameters that will be evaluated in the current state, and - returns a monadic value as well. + Decorator used to wrap user-defined coroutine-functions. + + This allows subclasses of :class:`StateMonad` to handle exceptions + inside user-defined coroutine functions, or do arbitrary other + processing. """ + return f - cls = cls._MONAD_BASE + @classmethod + def lift(cls, f): + """ + Decorator used to lift a coroutine function into the monad. - def run(_f, args, kwargs): - call = lambda: _f(*args, **kwargs) - x = call() - if inspect.iscoroutine(x): - def body(state): - if inspect.getcoroutinestate(x) == inspect.CORO_CLOSED: - _x = call() - else: - _x = x - - next_ = lambda: _x.send(None) - while True: - try: - future = next_() - except StopIteration as e: - val = e.value - break - else: - assert isinstance(future, cls) - try: - val, state = future._f(state) - except Exception as e: - # We need an intermediate variable here, since - # "e" is not really bound in this scope. - excep = e - next_ = lambda: _x.throw(excep) - else: - next_ = lambda: _x.send(val) - - if isinstance(val, cls): - return val._f(state) - else: - return (val, state) - val = cls.from_f(body, name=f.__qualname__) - else: - if isinstance(x, cls): - val = x - else: - val = cls.pure(x) + The decorated coroutine function can be called to set its parameters + values, and will return another callable. This callable will take the + :meth:`StateMon.make_state` method to initialize the state, and will + then run the computation. - return val + .. note:: If a coroutine function is decorated with + :meth:`StateMonad.lift` multiple times for various subclasses, each + state-initializing callable will return the state-initializing + callable of the next level in the decorator stack, starting from + the top. + """ + cls = cls._MONAD_BASE @functools.wraps(f) - def wrapper(*args, **kwargs): - async def _f(*args, **kwargs): - args = [ - (await arg) if isinstance(arg, cls) else arg - for arg in args - ] - kwargs = { - k: (await v) if isinstance(v, cls) else v - for k, v in kwargs.items() - } - return run(f, args, kwargs) - return run(_f, args, kwargs) - + def wrapper(*fargs, **fkwargs): + @functools.wraps(cls.make_state) + def make_state_wrapper(*sargs, _state_monad_private_wrap_coro=None, **skwargs): + _loop = functools.partial( + cls._loop, + cls=cls, + state=cls.make_state(*sargs, **skwargs), + # Only ask _loop to consume the coroutine if we are the + # top-level state monad in the stack + consume=_state_monad_private_wrap_coro is None, + ) + + if _state_monad_private_wrap_coro is None: + _state_monad_private_wrap_coro = lambda x: x + + wrap_coro = compose(cls._wrap_coroutine_f, _state_monad_private_wrap_coro) + + # We found the inner user-defined coroutine, so we just wrap it + # with the loop + if wrapper._state_monad_is_bottom: + return _loop( + _RestartableCoro( + lambda: wrap_coro(f)(*fargs, **fkwargs) + ), + ) + # If we are lifting an already-lifted function, we wrap with + # our loop + else: + def loop_wrapper(*args, **kwargs): + return _loop( + f(*fargs, **fkwargs)( + *args, + **kwargs, + _state_monad_private_wrap_coro=wrap_coro, + ), + ) + return loop_wrapper + + return _StateInitializer( + make_state_wrapper, + # Provide the top-most non lifted function in the decorator + # stack, so we can use it to await from it directly when + # composing lifted functions. + functools.partial( + wrapper._state_monad_coro_f, + *fargs, + **fkwargs, + ) + ) + + def find_user_f(f): + """ + Find the top-most non lifted function in the decorator stack. + """ + _f = f + while True: + # If we find a lifted function, we just pick it from there + try: + return (_f._state_monad_coro_f, False) + except AttributeError: + pass + + try: + _f = _f.__wrapped__ + except AttributeError: + break + + # If we could not find any lifted function, it means we are the + # bottom-most decorator in the stack and we can just take what we + # are given directly + return (f, True) + + # We wrap the coroutine function so that layers will accumulate and no + # _wrap_coroutine_f() will be missed + user_f, is_bottom = find_user_f(f) + wrapper._state_monad_coro_f = cls._wrap_coroutine_f(user_f) + wrapper._state_monad_is_bottom = is_bottom return wrapper @classmethod diff --git a/lisa/wlgen/rta.py b/lisa/wlgen/rta.py index 0d1dcb6cffdef9fa4e08bafff8a2fdf68dbd26c5..d8a36974a120960bf96b97b88048ed9d289201b9 100644 --- a/lisa/wlgen/rta.py +++ b/lisa/wlgen/rta.py @@ -150,9 +150,11 @@ from lisa.utils import ( kwargs_dispatcher, kwargs_forwarded_to, PartialInit, + compose, ) from lisa.wlgen.workload import Workload from lisa.conf import DeferredValueComputationError +from lisa.monad import StateMonad def _to_us(x): @@ -1606,8 +1608,8 @@ def leaf_precedence(val, **kwargs): """ Give precedence to the leaf values when combining with ``&``:: - phase = phase.with_props(prop_meta=({'hello': 'leaf'}) - phase = phase.with_props(prop_meta=leaf_precedence({'hello': 'root'}) + phase = phase.with_props(meta=({'hello': 'leaf'}) + phase = phase.with_props(meta=leaf_precedence({'hello': 'root'}) assert phase['meta'] == {'hello': 'leaf'} This allows setting a property with some kind of default value on a root @@ -1633,7 +1635,7 @@ def override(val, **kwargs): Override a property with the given value, rather than combining it with the property-specific ``&`` implementation:: - phase = phase.with_props(prop_cpus=override({1,2})) + phase = phase.with_props(cpus=override({1,2})) """ return _OverridingValue(val, **kwargs) @@ -1687,7 +1689,7 @@ def delete(): """ Remove the given property from the phase:: - phase = phase.with_props(prop_cpus=delete()) + phase = phase.with_props(cpus=delete()) """ return _DeletingValue() @@ -2458,6 +2460,105 @@ class UclampProperty(ComposableMultiConcretePropertyBase): ) +def task_factory(f): + from lisa.fuzz import Gen + + @functools.wraps(f) + def wrapper(*args, **kwargs): + decorator = compose(_TaskMonad.lift, Gen.lift) + _f = decorator(f)(*args, **kwargs) + + @_f.state_init_decorator + def with_state(rng=None, seed=None): + # First parameters for the Gen monad + # Then parameters for _TaskMonad state + return _f(rng=rng, seed=seed)() + return with_state + return wrapper + + +class _TaskMonad(StateMonad): + class _State: + def __init__(self): + self.levels = [[]] + + @property + def curr_level(self): + return self.levels[-1] + + def add_phase(self, phase): + self.curr_level.append(phase) + + def begin_prop(self): + self.levels.append([]) + + def end_prop(self, props): + phase = self.merge_level() + self.levels.pop() + phase = phase.with_phase_properties(props) + self.curr_level.append(phase) + + def merge_level(self): + level = self.levels[-1] + if level: + return functools.reduce( + operator.add, + level, + ) + else: + return RTAPhase() + + @classmethod + def make_state(cls): + return cls._State() + + @classmethod + def _wrap_coroutine_f(cls, f): + @functools.wraps(f) + async def wrapper(*args, **kwargs): + x = await f(*args, **kwargs) + assert x is None + state = await cls.get_state() + # The return value of the coroutine is the phase that corresponds + # to the current level, so that the top-level coroutine will return + # the top-level phase. + return state.merge_level() + return wrapper + + +class _DSLTaskMonad(_TaskMonad): + def __init__(self, f): + @functools.wraps(f) + def wrapper(state): + f(state) + return (None, state) + super().__init__(wrapper) + + +class _PhaseMonad(_DSLTaskMonad): + def __init__(self, phase): + super().__init__(lambda state: state.add_phase(phase)) + + +class _WloadMonad(_PhaseMonad): + def __init__(self, wload): + phase = RTAPhase(prop_wload=wload) + super().__init__(phase) + + +class Properties: + def __init__(self, **kwargs): + self.props = RTAPhaseProperties.from_polymorphic(kwargs) + + async def __aenter__(self): + await _DSLTaskMonad(lambda state: state.begin_prop()) + return + + async def __aexit__(self, exc_type, exc_value, traceback): + await _DSLTaskMonad(lambda state: state.end_prop(self.props)) + return + + class WloadPropertyBase(ConcretePropertyBase): """ Phase workload. @@ -2480,6 +2581,9 @@ class WloadPropertyBase(ConcretePropertyBase): def val(self): return self + def __await__(self): + return (yield from _WloadMonad(self).__await__()) + def __add__(self, other): """ Adding two workloads together concatenates them. @@ -3175,6 +3279,9 @@ class _RTAPhaseBase: return '\n\n'.join(starmap(make, sorted(properties.items()))) + def __await__(self): + return (yield from _PhaseMonad(self).__await__()) + class RTAPhaseBase(_RTAPhaseBase, SimpleHash, Mapping, abc.ABC): """