Source code for malcolm.core.controller

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union

from annotypes import Anno, stringify_error

from malcolm.compat import OrderedDict

from .alarm import Alarm
from .camel import camel_to_title
from .concurrency import Queue, RLock, Spawned
from .context import Context
from .errors import FieldError, NotWriteableError, UnexpectedError
from .hook import Hook, Hookable, start_hooks, wait_hooks
from .info import Info
from .models import AttributeModel, BlockModel, MethodLog, MethodModel, Model
from .notifier import Notifier, freeze
from .part import FieldRegistry, InfoRegistry, Part, PartRegistrar
from .request import Get, Post, Put, Request, Subscribe, Unsubscribe
from .response import Response
from .tags import method_return_unpacked, version_tag
from .timestamp import TimeStamp
from .views import Block, make_view

Field = Union[AttributeModel, MethodModel]
CallbackResponses = List[Tuple[Callable[[Response], None], Response]]
if TYPE_CHECKING:
    from .process import Process

# This is a good default value for a timeout. It is used to wait for abort
# below, and is imported in a number of other Controller subclasses
DEFAULT_TIMEOUT = 10.0


with Anno("The Malcolm Resource Identifier for the Block produced"):
    AMri = str
with Anno("Description of the Block produced by the controller"):
    ADescription = str


[docs]class Controller(Hookable): process = None def __init__(self, mri: AMri, description: ADescription = "") -> None: self.set_logger(mri=mri) self.name = mri self.mri = mri self.parts: Dict[str, Part] = OrderedDict() self._lock = RLock() self._block = BlockModel() self._block.meta.set_description(description) self._block.meta.set_label(mri) self._block.meta.set_tags([version_tag()]) self._notifier = Notifier(mri, self._lock, self._block) self._block.set_notifier_path(self._notifier, [mri]) self._write_functions: Dict[str, Callable[..., Any]] = {} self.field_registry = FieldRegistry() self.info_registry = InfoRegistry() def setup(self, process: "Process") -> None: self.process = process self.add_initial_part_fields() def add_part(self, part: Part) -> None: assert ( part.name not in self.parts ), f"Part {part.name!r} already exists in Controller {self.mri!r}" part.setup(PartRegistrar(self.field_registry, self.info_registry, part)) self.parts[part.name] = part def add_block_field( self, name: str, child: Field, writeable_func: Callable[..., Any], needs_context: bool, ) -> None: if writeable_func: if needs_context: # Wrap func def func_wrapper(*args, **kwargs): return writeable_func(Context(self.process), *args, **kwargs) self._write_functions[name] = func_wrapper else: self._write_functions[name] = writeable_func child.meta.set_writeable(True) if not child.meta.label: child.meta.set_label(camel_to_title(name)) self._block.set_endpoint_data(name, child) def add_initial_part_fields(self) -> None: for part_fields in self.field_registry.fields.values(): for name, child, writeable_func, needs_context in part_fields: self.add_block_field(name, child, writeable_func, needs_context) @property # type: ignore @contextmanager def lock_released(self): self._lock.release() try: yield finally: self._lock.acquire() @property def changes_squashed(self): return self._notifier.changes_squashed def block_view(self, context: Context = None) -> Block: if context is None: assert self.process, "No process for context." context = Context(self.process) with self._lock: child_view = make_view(self, context, self._block) return child_view
[docs] def make_view(self, context: Context, data: Model, child_name: str) -> Any: """Make a child View of data[child_name]""" with self._lock: child = data[child_name] child_view = make_view(self, context, child) return child_view
[docs] def handle_request(self, request: Request) -> Spawned: """Spawn a new thread that handles Request""" assert self.process, "No process to handle request" return self.process.spawn(self._handle_request, request)
def _handle_request(self, request: Request) -> None: responses = [] with self._lock: if isinstance(request, Get): handler = self._handle_get elif isinstance(request, Put): handler = self._handle_put elif isinstance(request, Post): handler = self._handle_post elif isinstance(request, Subscribe): handler = self._notifier.handle_subscribe elif isinstance(request, Unsubscribe): handler = self._notifier.handle_unsubscribe else: raise UnexpectedError(f"Unexpected request {request}") try: responses += handler(request) except Exception as e: responses.append(request.error_response(e)) for cb, response in responses: try: cb(response) except Exception: self.log.exception(f"Exception notifying {response}") raise def _handle_get(self, request: Get) -> CallbackResponses: """Called with the lock taken""" data = self._block for i, endpoint in enumerate(request.path[1:]): try: data = data[endpoint] except KeyError: if hasattr(data, "typeid"): typ = data.typeid else: typ = type(data) path = ".".join(request.path[: i + 1]) raise UnexpectedError( f"Object '{path}' of type {typ!r} has no attribute '{endpoint}'" ) # Important to freeze now with the lock so we get a consistent set serialized = freeze(data) ret = [request.return_response(serialized)] return ret def check_field_writeable(self, field): if not field.meta.writeable: raise NotWriteableError(f"Field {field.path} is not writeable") def get_put_function(self, attribute_name): return self._write_functions[attribute_name] def _handle_put(self, request: Put) -> CallbackResponses: """Called with the lock taken""" attribute_name = request.path[1] try: attribute = self._block[attribute_name] except KeyError: raise FieldError(f"Block '{self.mri}' has no Attribute '{attribute_name}'") assert isinstance( attribute, AttributeModel ), f"Cannot Put to {attribute.path} which is a {type(attribute)}" self.check_field_writeable(attribute) put_function = self.get_put_function(attribute_name) value = attribute.meta.validate(request.value) with self.lock_released: result = put_function(value) if request.get and result is None: # We asked for a Get, and didn't get given a return, so do return # the current value. Don't serialize here as value is immutable # (as long as we don't try too hard to break the rules) result = self._block[attribute_name].value elif not request.get: # We didn't ask for a Get, so throw result away result = None ret = [request.return_response(result)] return ret def get_post_function(self, method_name): return self._write_functions[method_name] def update_method_logs( self, method, took_value, took_ts, returned_value, returned_alarm ): with self.changes_squashed: method.set_took( MethodLog( value=method.meta.takes.validate(took_value, add_missing=True), present=[x for x in method.meta.takes.elements if x in took_value], timeStamp=took_ts, ) ) method.set_returned( MethodLog( value=method.meta.returns.validate( returned_value, add_missing=True ), present=[ x for x in method.meta.returns.elements if x in returned_value ], alarm=returned_alarm, ) ) def _handle_post(self, request: Post) -> CallbackResponses: """Called with the lock taken""" method_name = request.path[1] try: method = self._block[method_name] except KeyError: raise FieldError(f"Block '{self.mri}' has no Method '{method_name}'") assert isinstance( method, MethodModel ), f"Cannot Post to {method.path} which is a {type(method)}" self.check_field_writeable(method) post_function = self.get_post_function(method_name) took_ts = TimeStamp() took_value = method.meta.takes.validate(request.parameters) returned_alarm = Alarm.ok returned_value = {} try: with self.lock_released: result = post_function(**took_value) if method_return_unpacked() in method.meta.tags: # Single element, wrap in a dict returned_value = {"return": result} elif result is None: returned_value = {} else: # It should already be an object that serializes to a dict returned_value = result except Exception as e: returned_alarm = Alarm.major(stringify_error(e)) raise finally: self.update_method_logs( method, took_value, took_ts, returned_value, returned_alarm ) # Don't need to freeze as the result should be immutable ret = [request.return_response(result)] return ret def run_hooks(self, hooks: Iterable[Hook]) -> Dict[str, List[Info]]: return self.wait_hooks(*self.start_hooks(hooks)) def start_hooks(self, hooks: Iterable[Hook]) -> Tuple[Queue, List[Hook]]: # Hooks might be a generator, so convert to a list hooks = list(hooks) if not hooks: return Queue(), [] self.log.debug(f"{self.mri}: {hooks[0].name}: Starting hook") assert self.process, "No process for starting hooks" for hook in hooks: hook.set_spawn(self.process.spawn) # Take the lock so that no hook abort can come in between now and # the spawn of the context with self._lock: hook_queue, hook_spawned = start_hooks(hooks) return hook_queue, hook_spawned def wait_hooks( self, hook_queue: Queue, hook_spawned: List[Hook] ) -> Dict[str, List[Info]]: if hook_spawned: return_dict = wait_hooks( self.log, hook_queue, hook_spawned, DEFAULT_TIMEOUT ) else: self.log.debug(f"{self.mri}: No Parts hooked") return_dict = {} return return_dict