Source code for malcolm.core.hook

import inspect
import logging
import time
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from annotypes import Anno, WithCallTypes

from malcolm.compat import OrderedDict

from .concurrency import Queue, Spawned
from .errors import AbortedError
from .info import Info
from .loggable import Loggable

# Create a module level logger
log = logging.getLogger(__name__)


# Return type of a hooked function
T = TypeVar("T")

# A function registered against a Hook; it is called with keyword arguments
Hooked = Callable[..., T]
# Given the argument names a Hook can supply, return the argument names the
# hooked function wants passed to it
ArgsGen = Callable[[List[str]], List[str]]


def make_args_gen(func: Callable) -> ArgsGen:
    """Make an ArgsGen that demands the arguments listed in func.call_types

    Args:
        func: The function that will be run by a Hook. If it takes any
            positional arguments other than ``self`` it must have a
            ``call_types`` attribute (e.g. from ``@add_call_types``)

    Returns:
        A callable that ignores the argument names it is given and returns a
        new list of the keys of ``func.call_types`` (empty if it has none)

    Raises:
        TypeError: If func takes arguments but has no call_types
    """
    call_types = getattr(func, "call_types", {})
    arg_spec = inspect.getfullargspec(func)
    need_args = [k for k in arg_spec.args if k != "self"]

    if need_args and not call_types:
        raise TypeError(
            f"Function {func} takes arguments but doesn't have call_types. Did you "
            "forget to decorate with @add_call_types?"
        )

    def args_gen(keys: List[str]) -> List[str]:
        # The supplied keys are not consulted: func always demands exactly the
        # arguments in its call_types. Return a real list (as ArgsGen promises)
        # rather than a live dict_keys view.
        return list(call_types)

    return args_gen


class Hookable(Loggable, WithCallTypes):
    """Baseclass of something that can be attached to a hook"""

    name: Union[str, None] = None
    hooked: Union[Dict[Type["Hook"], Tuple[Hooked, ArgsGen]], None] = None

    def register_hooked(
        self,
        hooks: Union[Type["Hook"], Sequence[Type["Hook"]]],
        func: Hooked,
        args_gen: Optional[ArgsGen] = None,
    ) -> None:
        """Register func to be run when any of the given Hook classes are run

        Args:
            hooks: A Hook class, or a sequence of Hook classes, to register
                func against
            func: The callable to run when one of those Hooks is run
            args_gen: Optional callable choosing which argument names are
                passed to func. Defaults to make_args_gen(func), which demands
                the keys of func.call_types
        """
        if self.hooked is None:
            self.hooked = {}
        if args_gen is None:
            args_gen = make_args_gen(func)
        hook_classes = hooks if isinstance(hooks, Sequence) else [hooks]
        entry = (func, args_gen)
        for hook_cls in hook_classes:
            # A later registration for the same Hook class replaces an earlier one
            self.hooked[hook_cls] = entry

    def on_hook(self, hook: "Hook") -> None:
        """Run the function registered for this hook's exact class, if any

        The lookup uses type(hook), so subclasses of a registered Hook class
        are not matched. When an entry exists, hook(func, args_gen) is called
        so the Hook can spawn func; otherwise nothing happens.
        """
        registered = self.hooked
        if registered is None:
            return
        entry = registered.get(type(hook))
        if entry is None:
            return
        func, args_gen = entry
        hook(func, args_gen)
# Annotated alias of Hookable used as the type of Hook's ``child`` argument.
# The ``with Anno(...)`` block is annotypes syntax that presumably attaches this
# description to the AHookable name for call_types introspection
# (NOTE(review): confirm against annotypes docs).
with Anno("The child that the hook is being passed to"): AHookable = Hookable
class Hook(Generic[T], WithCallTypes):
    """Something that children can register with to be called

    A Hook holds the keyword arguments it can supply. When called with a
    function it spawns that function with the subset of arguments the function
    demands, then puts ``(hook, result_or_exception)`` on its queue.
    """

    def __init__(self, child: AHookable, **kwargs: Any) -> None:
        self.child = child
        # Arguments that may be passed to the hooked function
        self._kwargs = kwargs
        # Set by set_queue(); receives (self, result) when the function ends
        self._queue: Union[Queue, None] = None
        # Set by set_spawn(); used to spawn _run
        self._spawn: Union[Callable[..., Spawned], None] = None
        # The spawned _run, set by __call__; a Hook can only spawn once
        self.spawned: Union[Spawned, None] = None

    @property
    def name(self) -> str:
        """The name of this Hook's class"""
        return type(self).__name__

    def set_spawn(self, spawn: Callable[..., Spawned]) -> "Hook":
        """Set the function used to spawn the hooked function, returning self"""
        self._spawn = spawn
        return self

    def set_queue(self, queue: Queue) -> "Hook":
        """Set the queue that the result is put on, returning self"""
        self._queue = queue
        return self

    def prepare(self) -> None:
        """Override this if we need to prepare before running"""
        pass

    def __call__(
        self, func: Callable[..., T], args_gen: Optional[ArgsGen] = None
    ) -> None:
        """Spawn func, passing the kwargs it demands

        Args:
            func: The function to spawn
            args_gen: Chooses which of this Hook's kwargs to pass. Defaults to
                make_args_gen(func), i.e. the keys of func.call_types

        Raises:
            AssertionError: If this Hook already spawned a function, if func
                demands an argument this Hook does not supply, or if
                set_spawn() has not been called
        """
        assert (
            self.spawned is None
        ), "Hook has already spawned a function, cannot run another"
        self.prepare()
        if args_gen is None:
            args_gen = make_args_gen(func)
        # TODO: should we check the return types here?
        supplied = list(self._kwargs)
        demanded = args_gen(supplied)
        assert set(supplied).issuperset(
            demanded
        ), f"Hook demanded arguments {demanded}, but only supplied {supplied}"
        kwargs = {k: self._kwargs[k] for k in demanded}
        assert self._spawn is not None, "No spawned function"
        self.spawned = self._spawn(self._run, func, kwargs)

    def _run(self, func: Callable[..., T], kwargs: Dict[str, Any]) -> None:
        """Run func(**kwargs) and put (self, result) on the queue

        If func or validate_return raises an Exception (BaseExceptions that are
        not Exceptions are not caught), it is logged and put on the queue in
        place of the result.
        """
        result: Union[T, Exception]
        try:
            result = func(**kwargs)
            result = self.validate_return(result)
        except AbortedError as e:
            log.info("%s: %s has been aborted", self.child, func)
            result = e
        except Exception as e:  # pylint:disable=broad-except
            log.exception(
                "%s: %s(**%s) raised exception %s", self.child, func, kwargs, e
            )
            result = e
        # Always reply, even on error, as wait_hooks blocks until every
        # spawned hook has put something on the queue
        assert self._queue is not None, "No queue to put result"
        self._queue.put((self, result))

    def stop(self) -> None:
        """Override this if we can stop"""
        raise RuntimeError(f"{self.name} cannot be stopped")

    def validate_return(self, ret: T) -> Any:
        """Override this if the function is expected to return something, to
        validate its value

        The default expects a falsy return (e.g. None) and returns None.
        """
        assert not ret, f"Expected no return, got {ret}"
        return None
def start_hooks(hooks: List[Hook]) -> Tuple[Queue, List[Hook]]:
    """Offer each hook to its child and collect the hooks that spawned

    Args:
        hooks: The Hooks to run. Presumably each has had set_spawn() called
            already, as Hook.__call__ asserts it

    Returns:
        The queue that spawned hooks put (hook, result) tuples on, and the
        list of hooks whose child spawned a function
    """
    # This queue will hold (hook, result) tuples
    hook_queue = Queue()
    hook_spawned = []
    # now start them off
    for hook in hooks:
        hook.set_queue(hook_queue)
        hook.child.on_hook(hook)
        if hook.spawned is not None:
            hook_spawned.append(hook)
    return hook_queue, hook_spawned


def wait_hooks(
    logger: Optional[logging.Logger],
    hook_queue: Queue,
    hook_spawned: List[Hook],
    timeout: Optional[float] = None,
    exception_check: bool = True,
) -> Dict[str, List[Info]]:
    """Wait for every spawned hook to put its result on hook_queue

    Args:
        logger: If given, each result is logged at debug level
        hook_queue: The queue returned by start_hooks
        hook_spawned: The spawned hooks returned by start_hooks
        timeout: Time to wait for each spawned function to terminate once its
            result has arrived (or after stopping on error). It does not limit
            how long the functions may run for
        exception_check: If True, the first Exception result is raised, after
            stopping and waiting for all hooks unless it is an AbortedError.
            If False, Exceptions are returned like any other result

    Returns:
        An OrderedDict of child name -> returned value, in hook_spawned order
    """
    return_dict = OrderedDict()
    for spawned_hook in hook_spawned:
        return_dict[spawned_hook.child.name] = None
    # Monotonic clock so logged durations are not skewed by wall clock changes
    start = time.monotonic()
    hook_spawned_set = set(hook_spawned)
    while hook_spawned_set:
        hook: Hook
        ret: Any
        hook, ret = hook_queue.get()
        hook_spawned_set.remove(hook)
        # Wait for the process to terminate
        assert hook.spawned is not None, "No spawned process"
        hook.spawned.wait(timeout)
        duration = time.monotonic() - start
        if logger:
            if hook_spawned_set:
                logger.debug(
                    "%s: Child %s returned %r after %ss. Still waiting for %s",
                    hook.name,
                    hook.child.name,
                    ret,
                    duration,
                    [h.child.name for h in hook_spawned_set],
                )
            else:
                logger.debug(
                    "%s: Child %s returned %r after %ss. Returning...",
                    hook.name,
                    hook.child.name,
                    ret,
                    duration,
                )

        if isinstance(ret, Exception) and exception_check:
            if not isinstance(ret, AbortedError):
                # If AbortedError, all tasks have already been stopped.
                # Got an error, so stop and wait all hook runners.
                # NOTE: Hook.stop() raises RuntimeError unless a subclass
                # overrides it, so hooks run here are expected to override it
                for h in hook_spawned:
                    h.stop()
                # Wait for them to finish
                for h in hook_spawned:
                    assert h.spawned is not None, "No spawned functions"
                    h.spawned.wait(timeout)
            raise ret
        else:
            return_dict[hook.child.name] = ret
    return return_dict