import asyncio
import inspect
import json
import signal
from collections import defaultdict
from collections.abc import Callable, Coroutine, Sequence
from functools import partial
from pathlib import Path
from typing import Annotated, Any, Optional, get_type_hints
import typer
from IPython.terminal.embed import InteractiveShellEmbed
from pydantic import BaseModel, ValidationError, create_model
from ruamel.yaml import YAML
from fastcs import __version__
from fastcs.attribute_io_ref import AttributeIORef
from fastcs.logging import (
GraylogEndpoint,
GraylogEnvFields,
GraylogStaticFields,
LogLevel,
configure_logging,
parse_graylog_env_fields,
parse_graylog_static_fields,
)
from fastcs.logging import logger as _fastcs_logger
from fastcs.tracer import Tracer
from .attributes import ONCE, AttrR, AttrW
from .controller import BaseController, Controller
from .controller_api import ControllerAPI
from .cs_methods import Command, Put, Scan
from .datatypes import T
from .exceptions import FastCSError, LaunchError
from .transport import Transport
from .util import validate_hinted_attributes
# Module-level tracer named after this module; used below to log an event each
# time an attribute updater is called.
tracer = Tracer(name=__name__)
# Module-level logger: the shared fastcs logger with this module's name bound
# as ``logger_name``.
logger = _fastcs_logger.bind(logger_name=__name__)
[docs]
class FastCS:
    """Launch a controller with the given transport(s) and track serving tasks.

    Construction initialises the controller on the event loop, builds its
    ``ControllerAPI``, links put methods to their attributes and initialises
    every transport with that API and the loop.
    """

    def __init__(
        self,
        controller: Controller,
        transports: Sequence[Transport],
        loop: asyncio.AbstractEventLoop | None = None,
    ):
        """Initialise ``controller`` and hand its API to each transport.

        Args:
            controller: The controller to serve.
            transports: The transports to expose the controller through.
            loop: Event loop to run on; ``asyncio.get_event_loop()`` is used
                when omitted.
        """
        # Set first so that ``__del__`` works even if a later step raises.
        self._scan_tasks: set[asyncio.Task] = set()
        self._loop = loop or asyncio.get_event_loop()
        self._controller = controller
        self._initial_coros = [controller.connect]

        # these initialise the controller & build its APIs
        self._loop.run_until_complete(controller.initialise())
        self._loop.run_until_complete(controller.attribute_initialise())
        validate_hinted_attributes(controller)
        self.controller_api = build_controller_api(controller)
        self._link_process_tasks()

        self._transports = transports
        for transport in self._transports:
            transport.initialise(controller_api=self.controller_api, loop=self._loop)

    def create_docs(self) -> None:
        """Ask every transport to create its documentation."""
        for transport in self._transports:
            transport.create_docs()

    def create_gui(self) -> None:
        """Ask every transport to create its GUI."""
        for transport in self._transports:
            transport.create_gui()

    def run(self):
        """Run ``serve`` on the event loop until it ends or a signal cancels it.

        SIGINT and SIGTERM cancel the serve task.
        """
        # Create the task on ``self._loop`` itself rather than on whichever
        # loop asyncio considers current, so a loop passed to ``__init__`` is
        # the one that both owns and runs the task.
        serve = self._loop.create_task(self.serve())
        self._loop.add_signal_handler(signal.SIGINT, serve.cancel)
        self._loop.add_signal_handler(signal.SIGTERM, serve.cancel)
        self._loop.run_until_complete(serve)

    def _link_process_tasks(self):
        """Link put methods to their attributes throughout the API tree."""
        for controller_api in self.controller_api.walk_api():
            _link_put_tasks(controller_api)

    def __del__(self):
        # The attribute is missing if the instance was never initialised.
        if hasattr(self, "_scan_tasks"):
            self._stop_scan_tasks()

    async def serve_routines(self):
        """Run the initial coroutines, then start the periodic scan tasks."""
        scans, initials = _get_scan_and_initial_coros(self.controller_api)
        self._initial_coros += initials
        await self._run_initial_coros()
        await self._start_scan_tasks(scans)

    async def _run_initial_coros(self):
        """Await each initial coroutine function in order."""
        for coro in self._initial_coros:
            await coro()

    async def _start_scan_tasks(
        self, coros: list[Callable[[], Coroutine[None, None, None]]]
    ):
        """Create one task per scan coroutine function and watch for failures."""
        self._scan_tasks = {self._loop.create_task(coro()) for coro in coros}
        for task in self._scan_tasks:
            task.add_done_callback(self._scan_done)

    def _scan_done(self, task: asyncio.Task):
        """Done callback for scan tasks: surface any exception the scan raised.

        Raises:
            FastCSError: If the task finished with an exception.
        """
        # A cancelled task has no result; ``task.result()`` would raise
        # ``CancelledError`` out of this callback on every normal shutdown.
        if task.cancelled():
            return
        try:
            task.result()
        except Exception as e:
            raise FastCSError(
                "Exception raised in scan method of "
                f"{self._controller.__class__.__name__}"
            ) from e

    def _stop_scan_tasks(self):
        """Cancel every scan task that has not finished yet."""
        for task in self._scan_tasks:
            if not task.done():
                try:
                    task.cancel()
                except (asyncio.CancelledError, RuntimeError):
                    # Best effort: this also runs from ``__del__``.
                    pass
                except Exception as e:
                    raise RuntimeError("Unhandled exception in stop scan tasks") from e

    async def serve(self) -> None:
        """Serve the controller routines, every transport and a shell together.

        Cancellation is swallowed so that cancelling the serve task (as the
        signal handlers installed by ``run`` do) ends serving quietly.

        Raises:
            RuntimeError: If a transport's context shares a key with the
                context gathered so far, or if a served coroutine raises.
        """
        context = {
            "controller": self._controller,
            "controller_api": self.controller_api,
            "transports": [
                transport.__class__.__name__ for transport in self._transports
            ],
        }

        # Merge every transport's context before any coroutine object is
        # created, so a duplicate key does not leave un-awaited coroutines.
        for transport in self._transports:
            common_context = context.keys() & transport.context.keys()
            if common_context:
                current = {k: context[k] for k in common_context}
                incoming = {k: transport.context[k] for k in common_context}
                raise RuntimeError(
                    "Duplicate context keys found between "
                    f"current context {current} "
                    f"and {transport.__class__.__name__} context: "
                    f"{incoming}"
                )
            context.update(transport.context)

        coros = [self.serve_routines()]
        coros.extend(transport.serve() for transport in self._transports)
        coros.append(self._interactive_shell(context))

        logger.info(
            "Starting FastCS",
            controller=self._controller,
            transports=f"[{', '.join(str(t) for t in self._transports)}]",
        )

        try:
            await asyncio.gather(*coros)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            raise RuntimeError("Unhandled exception in serve") from e

    async def _interactive_shell(self, context: dict[str, Any]):
        """Spawn interactive shell in another thread and wait for it to complete."""

        def run(coro: Coroutine[None, None, None]):
            """Run coroutine on FastCS event loop from IPython thread."""

            def wrapper():
                asyncio.create_task(coro)

            self._loop.call_soon_threadsafe(wrapper)

        async def interactive_shell(
            context: dict[str, object], stop_event: asyncio.Event
        ):
            """Run interactive shell in a new thread."""
            shell = InteractiveShellEmbed()
            await asyncio.to_thread(partial(shell.mainloop, local_ns=context))

            stop_event.set()

        # Expose ``run`` in the shell namespace.
        context["run"] = run

        stop_event = asyncio.Event()
        # Hold a reference for the duration of the wait; the event loop itself
        # only keeps weak references to tasks.
        shell_task = self._loop.create_task(interactive_shell(context, stop_event))
        await stop_event.wait()
        del shell_task
def _link_put_tasks(controller_api: ControllerAPI) -> None:
for name, method in controller_api.put_methods.items():
name = name.removeprefix("put_")
attribute = controller_api.attributes[name]
match attribute:
case AttrW():
attribute.add_process_callback(method.fn)
case _:
raise FastCSError(
f"Attribute type {type(attribute)} does not"
f"support put operations for {name}"
)
def _get_scan_and_initial_coros(
    root_controller_api: ControllerAPI,
) -> tuple[list[Callable], list[Callable]]:
    """Collect the periodic scan coroutines and the run-once coroutines.

    Walks every API under ``root_controller_api``, grouping scan methods and
    attribute updaters by period, and returns
    ``(periodic scan coroutines, initial coroutines)``.
    """
    initial_coros: list[Callable] = []
    scan_dict: dict[float, list[Callable]] = defaultdict(list)

    for api in root_controller_api.walk_api():
        _add_scan_method_tasks(scan_dict, api)
        _add_attribute_updater_tasks(scan_dict, initial_coros, api)

    return _get_periodic_scan_coros(scan_dict), initial_coros
def _add_scan_method_tasks(
scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI
):
for method in controller_api.scan_methods.values():
scan_dict[method.period].append(method.fn)
def _add_attribute_updater_tasks(
scan_dict: dict[float, list[Callable]],
initial_coros: list[Callable],
controller_api: ControllerAPI,
):
for attribute in controller_api.attributes.values():
match attribute:
case (
AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute
):
callback = _create_updater_callback(attribute)
if update_period is ONCE:
initial_coros.append(callback)
elif update_period is not None:
scan_dict[update_period].append(callback)
def _create_updater_callback(attribute: AttrR[T]):
    """Build a no-argument coroutine function that updates ``attribute``.

    Each call logs a tracer event and awaits ``attribute.update()``. Any
    ``Exception`` is logged together with its traceback and then re-raised.
    """

    async def callback():
        try:
            tracer.log_event("Call attribute updater", topic=attribute)
            await attribute.update()
        except Exception:
            failure_log = logger.opt(exception=True)
            failure_log.error("Update loop failed", attribute=attribute)
            raise

    return callback
def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_coros: list[Callable] = []
for period, methods in scan_dict.items():
periodic_scan_coros.append(_create_periodic_scan_coro(period, methods))
return periodic_scan_coros
def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable:
async def _sleep():
await asyncio.sleep(period)
methods.append(_sleep) # Create periodic behavior
async def scan_coro() -> None:
while True:
await asyncio.gather(*[method() for method in methods])
return scan_coro
def build_controller_api(controller: Controller) -> ControllerAPI:
    """Build the ``ControllerAPI`` for ``controller``, starting from an empty path."""
    root_path: list[str] = []
    return _build_controller_api(controller, root_path)
def _build_controller_api(
    controller: BaseController, path: list[str]
) -> ControllerAPI:
    """Build the ``ControllerAPI`` for ``controller`` and, recursively, its children.

    Every name in ``dir(controller)`` is inspected; enabled ``Put``, ``Scan``
    and ``Command`` members are collected by name. Each sub controller gets
    its own API, with its name appended to ``path``.
    """
    put_methods: dict[str, Put] = {}
    scan_methods: dict[str, Scan] = {}
    command_methods: dict[str, Command] = {}

    # Checked in this order; the first kind that matches claims the member.
    method_kinds = (
        (Put, put_methods),
        (Scan, scan_methods),
        (Command, command_methods),
    )

    for member_name in dir(controller):
        member = getattr(controller, member_name)
        for kind, registry in method_kinds:
            if isinstance(member, kind) and getattr(member, "enabled", None) is True:
                registry[member_name] = member
                break

    return ControllerAPI(
        path=path,
        attributes=controller.attributes,
        scan_methods=scan_methods,
        put_methods=put_methods,
        command_methods=command_methods,
        sub_apis={
            child_name: _build_controller_api(child, path + [child_name])
            for child_name, child in controller.get_sub_controllers().items()
        },
        description=controller.description,
    )
[docs]
def launch(
    controller_class: type[Controller],
    version: str | None = None,
) -> None:
    """Entry point for starting a FastCS application.

    Type hints on the controller's ``__init__`` are used to build a
    command-line interface that describes and gathers the required
    configuration before the application is instantiated.

    Args:
        controller_class: The FastCS Controller to instantiate. It must have a
            type-hinted ``__init__`` method and no more than 2 arguments.
        version: The version of the FastCS Controller. Optional.

    Raises:
        LaunchError: If the class's ``__init__`` is not as expected.

    Example of the expected Controller implementation::

        class MyController(Controller):
            def __init__(self, my_arg: MyControllerOptions) -> None:
                ...

    Typical usage::

        if __name__ == "__main__":
            launch(MyController)
    """
    app = _launch(controller_class, version)
    app()
def _launch(
    controller_class: type[Controller],
    version: str | None = None,
) -> typer.Typer:
    """Build the typer application used to launch ``controller_class``.

    The application has a ``--version`` option and two commands: ``schema``
    prints the JSON schema of the options model, and ``run`` loads a yaml
    config, validates it, builds the controller and serves it with ``FastCS``.
    """
    fastcs_options = _extract_options_model(controller_class)
    launch_typer = typer.Typer()

    class LaunchContext:
        # Carries the controller class and its options model to the commands
        # through ``ctx.obj``.
        def __init__(self, controller_class, fastcs_options):
            self.controller_class = controller_class
            self.fastcs_options = fastcs_options

    def version_callback(value: bool):
        # Eager ``--version`` callback: only acts when the flag was given.
        if not value:
            return
        if version:
            print(f"{controller_class.__name__}: {version}")
        print(f"FastCS: {__version__}")
        raise typer.Exit()

    # NOTE: deliberately no docstring on ``main``, since typer may use a
    # callback's docstring as help text.
    @launch_typer.callback()
    def main(
        ctx: typer.Context,
        version: Optional[bool] = typer.Option(  # noqa (Optional required for typer)
            None,
            "--version",
            callback=version_callback,
            is_eager=True,
            help=f"Display the {controller_class.__name__} version.",
        ),
    ):
        ctx.obj = LaunchContext(controller_class, fastcs_options)

    @launch_typer.command(help=f"Produce json schema for a {controller_class.__name__}")
    def schema(ctx: typer.Context):
        options_model = ctx.obj.fastcs_options
        print(json.dumps(options_model.model_json_schema(), indent=2))

    @launch_typer.command(help=f"Start up a {controller_class.__name__}")
    def run(
        ctx: typer.Context,
        config: Annotated[
            Path,
            typer.Argument(
                help=f"A yaml file matching the {controller_class.__name__} schema"
            ),
        ],
        log_level: Annotated[
            Optional[LogLevel],  # noqa: UP045
            typer.Option(),
        ] = None,
        graylog_endpoint: Annotated[
            Optional[GraylogEndpoint],  # noqa: UP045
            typer.Option(
                help="Endpoint for graylog logging - '<host>:<port>'",
                parser=GraylogEndpoint.parse_graylog_endpoint,
            ),
        ] = None,
        graylog_static_fields: Annotated[
            Optional[GraylogStaticFields],  # noqa: UP045
            typer.Option(
                help="Fields to add to graylog messages with static values",
                parser=parse_graylog_static_fields,
            ),
        ] = None,
        graylog_env_fields: Annotated[
            Optional[GraylogEnvFields],  # noqa: UP045
            typer.Option(
                help="Fields to add to graylog messages from environment variables",
                parser=parse_graylog_env_fields,
            ),
        ] = None,
    ):
        """
        Start the controller
        """
        configure_logging(
            log_level, graylog_endpoint, graylog_static_fields, graylog_env_fields
        )

        launch_context = ctx.obj
        controller_type = launch_context.controller_class
        options_model = launch_context.fastcs_options

        options_yaml = YAML(typ="safe").load(config)

        try:
            instance_options = options_model.model_validate(options_yaml)
        except ValidationError as e:
            reported_errors = json.loads(e.json())
            transport_failed = any(
                "transport" in error["loc"] for error in reported_errors
            )
            if not transport_failed:
                raise LaunchError("Failed to validate config") from e
            raise LaunchError(
                "Failed to validate transports. "
                "Are the correct fastcs extras installed? "
                f"Available transports:\n{Transport.subclasses}",
            ) from e

        # The options model only has a ``controller`` field when the
        # controller's ``__init__`` takes an options argument.
        if hasattr(instance_options, "controller"):
            controller = controller_type(instance_options.controller)
        else:
            controller = controller_type()

        instance = FastCS(
            controller,
            instance_options.transport,
            loop=asyncio.get_event_loop(),
        )

        instance.create_gui()
        instance.create_docs()
        instance.run()

    return launch_typer
def _extract_options_model(controller_class: type[Controller]) -> type[BaseModel]:
    """Create the pydantic options model for launching ``controller_class``.

    The model always has a required ``transport`` field. When the controller's
    ``__init__`` takes one argument besides ``self``, the model also has a
    required ``controller`` field typed from that argument's type hint. Extra
    fields are forbidden.

    Raises:
        LaunchError: If ``__init__`` does not take exactly 1 or 2 positional
            arguments (including ``self``), or if its options argument has no
            type hint.
    """
    init = controller_class.__init__
    sig = inspect.signature(init)
    args = inspect.getfullargspec(init)[0]

    if len(args) not in (1, 2):
        raise LaunchError(
            f"Expected no more than 2 arguments for '{controller_class.__name__}"
            f".__init__' but received {len(args)} as `{sig}`"
        )

    # ``controller`` is inserted first so it stays ahead of ``transport``.
    fields: dict[str, Any] = {}
    if len(args) == 2:
        options_arg = args[-1]
        hints = get_type_hints(init)
        # Look the hint up by parameter name. Taking "the last hint" would
        # pick up an annotated keyword-only parameter instead.
        if options_arg not in hints:
            raise LaunchError(
                f"Expected typehinting in '{controller_class.__name__}"
                f".__init__' but received {sig}. Add a typehint for `{options_arg}`."
            )
        fields["controller"] = (hints[options_arg], ...)
    fields["transport"] = (list[Transport.union()], ...)

    return create_model(
        controller_class.__name__,
        __config__={"extra": "forbid"},
        **fields,
    )
[docs]
def get_controller_schema(target: type[Controller]) -> dict[str, Any]:
    """Return the JSON schema of the options model for the given controller class."""
    return _extract_options_model(target).model_json_schema()