from operator import itemgetter
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
import numpy as np
from annotypes import Anno, add_call_types
from scanpointgenerator import CompoundGenerator, Point, Points
from malcolm.core import APartName, Block, Context, PartRegistrar
from malcolm.modules import builtin, pmac, scanning
from malcolm.modules.pmac.util import MinTurnaround, get_min_turnaround
from ..doublebuffer import MIN_PULSE, TICK, DoubleBuffer, SequencerRows
from ..util import Trigger
# The SEQ.table attributes that should be present in PANDA.exports
SEQ_TABLES = ("seqTableA", "seqTableB")
with Anno("Scannable name for sequencer input"):
APos = str
# Pull re-used annotypes into our namespace in case we are subclassed
APartName = APartName
AMri = builtin.parts.AMri
AInitialVisibility = builtin.parts.AInitialVisibility
# SeqTriggers processing batch size
BATCH_SIZE: int = 5000
class SeqTriggers:
"""A class that generates Sequencer (SEQ) block triggers from a Points generator."""
def __init__(
self,
generator: CompoundGenerator,
axis_mapping: Dict[str, pmac.infos.MotorInfo],
trigger_enums: Dict[Tuple[str, bool], str],
min_turnaround: float,
) -> None:
self.generator: CompoundGenerator = generator
self.axis_mapping: Dict[str, pmac.infos.MotorInfo] = axis_mapping
self.trigger_enums: Dict[Tuple[str, bool], str] = trigger_enums
self.min_turnaround: float = min_turnaround
self.last_point: Point = None
@staticmethod
def _what_moves_most(
point: Point, axis_mapping: Dict[str, pmac.infos.MotorInfo]
) -> Tuple[str, int, bool]:
"""Return the axis from `axis_mapping` that moves most for this point."""
# TODO: should use new velocity calcs when Giles has finished
# {axis_name: abs(diff_cts)}
diffs = {}
# {axis_name: (compare_cts, increasing)}
compare_increasing = {}
for s, info in axis_mapping.items():
compare_cts = info.in_cts(point.lower[s])
centre_cts = info.in_cts(point.positions[s])
diff_cts = centre_cts - compare_cts
if diff_cts != 0:
diffs[s] = abs(diff_cts)
compare_increasing[s] = (compare_cts, diff_cts > 0)
assert diffs, (
"Can't work out a compare point for %s, maybe none of the axes "
"connected to the PandA are moving during the scan point?" % point.positions
)
# Sort on abs(diff), take the biggest
axis_name = max(diffs.items(), key=itemgetter(1))[0]
compare_cts, increasing = compare_increasing[axis_name]
return axis_name, compare_cts, increasing
def _how_long_moving_wrong_way(
self, axis_name: str, point: Point, increasing: bool
) -> float:
"""Return the duration that the given axis is moving in the opposite direction
from that required for `point`, during the prior turnaround."""
assert self.min_turnaround, f"{self.name}: no MinTurnaround assigned"
min_turnaround = max(self.min_turnaround.time, point.delay_after)
time_arrays, velocity_arrays = pmac.util.profile_between_points(
self.axis_mapping,
self.last_point,
point,
min_turnaround,
self.min_turnaround.interval,
)
info = self.axis_mapping[axis_name]
time_array = time_arrays[info.scannable]
velocity_array = velocity_arrays[info.scannable]
# Work backwards through the velocity array until we are going the
# opposite way
i = 0
for i, v in reversed(list(enumerate(velocity_array))):
# Divide v by resolution so it is in counts
v /= info.resolution
if (increasing and v <= 0) or (not increasing and v >= 0):
# The axis is stationary or going the wrong way at this
# point, so we should be blind before then
assert (
i < len(velocity_array) - 1
), f"Last point of {velocity_array} is wrong direction"
break
blind = time_array[i]
return blind
@staticmethod
def _get_row_indices(points: Points) -> Tuple[np.ndarray, np.ndarray]:
"""Generate list of start and end indices for separate rows.
The first point (index 0) is not registered as a separate row. This is because
for the first row of a scan the triggering is handled separately, and for later
batches of points the initial value is from a previous row.
"""
points_joined = pmac.util.all_points_joined(points)
if points_joined is not None and len(points_joined) > 0:
start_indices = np.nonzero(np.invert(points_joined))[0]
start_indices += 1
else:
start_indices = np.array([])
# end_index = start_index + size
end_indices = np.empty(len(start_indices), dtype=int)
if start_indices.size:
end_indices[:-1] = start_indices[1:]
end_indices[-1] = len(points)
return start_indices, end_indices
@staticmethod
def _create_immediate_rows(durations: Sequence[float]) -> SequencerRows:
"""Generate sequencer rows with 'Immediate' trigger type, from `durations`."""
if len(durations) == 0:
return SequencerRows()
pairwise_equal = np.empty(len(durations), dtype=bool)
pairwise_equal[0] = True # Initial duration starts first row
np.not_equal(durations[:-1], durations[1:], out=pairwise_equal[1:])
start_indices = np.nonzero(pairwise_equal)
seq_durations = durations[start_indices]
seq_lengths = np.diff(np.append(start_indices, len(durations)))
rows = SequencerRows()
for duration, count in zip(seq_durations, seq_lengths):
half_frame = int(round(duration / TICK / 2))
rows.add_seq_entry(count, half_duration=half_frame, live=1)
return rows
def _create_triggered_rows(
self, points: Points, start_index: int, end_index: int, add_blind: bool
) -> SequencerRows:
"""Generate sequencer rows corresponding to a triggered points row."""
initial_point: Point = points[start_index]
half_frame: int = int(round(initial_point.duration / TICK / 2))
rows = SequencerRows()
if self.trigger_enums:
# Position compare
# First row, or rows not joined
# Work out which axis moves most during this point
axis_name, compare_cts, increasing = self._what_moves_most(
initial_point, self.axis_mapping
)
if add_blind:
# How long to be blind for during the turnaround
blind = self._how_long_moving_wrong_way(
axis_name, initial_point, increasing
)
half_blind = int(round(blind / TICK / 2))
rows.add_seq_entry(half_duration=half_blind, dead=1)
# Create a compare point for the next row
rows.add_seq_entry(
trigger=self.trigger_enums[(axis_name, increasing)],
position=compare_cts,
half_duration=half_frame,
live=1,
)
else:
# Row trigger coming in on BITA
if add_blind:
# Produce dead pulse as soon as row has finished
rows.add_seq_entry(
half_duration=MIN_PULSE, dead=1, trigger=Trigger.BITA_0
)
rows.add_seq_entry(trigger=Trigger.BITA_1, half_duration=half_frame, live=1)
rows.extend(
self._create_immediate_rows(points.duration[start_index + 1 : end_index])
)
return rows
@staticmethod
def _overlapping_points_range(generator, start: int, end: int) -> Iterator[Points]:
"""Yield a series of `Points` objects that cover the given range.
Yielded points overlap by one point to indicate whether the start of a given
batch corresponds to the middle of a row.
"""
if start == end:
return
low_index = start
high_index = min(start + BATCH_SIZE, end)
while True:
yield generator.get_points(low_index, high_index)
if high_index == end:
break
low_index = high_index - 1 # Include final point from previous range
high_index = min(low_index + BATCH_SIZE + 1, end)
def get_rows(self, loaded_up_to: int, scan_up_to: int) -> Iterator[SequencerRows]:
"""Yield a series of `SequencerRows` that correspond to the given range."""
for points in self._overlapping_points_range(
self.generator, loaded_up_to, scan_up_to
):
if not self.axis_mapping:
# No position compare or row triggering required
durations = points.duration[1:] if self.last_point else points.duration
yield self._create_immediate_rows(durations)
self.last_point = points[-1]
else:
start_indices, end_indices = self._get_row_indices(points)
# Handle the first scan row from the current batch of points
end = start_indices[0] if start_indices.size else len(points)
if self.last_point is None:
# This is the beginning of the scan
yield self._create_triggered_rows(points, 0, end, False)
self.last_point = points[end - 1]
else:
# This is the beginning of subsequent batches.
# Sequence table rows are only added here if the previous batch
# of points finished in the middle of a continuous scan row.
# The first point of the current batch is from the previous batch.
yield self._create_immediate_rows(points.duration[1:end])
# Remaining scan rows from the current batch of points.
for start_i, end_i in zip(start_indices, end_indices):
yield self._create_triggered_rows(points, start_i, end_i, True)
self.last_point = points[end_i - 1]
rows = SequencerRows()
# add one last dead frame signal
rows.add_seq_entry(half_duration=MIN_PULSE, dead=1)
# add continuous loop to prevent sequencer switch
rows.add_seq_entry(count=0)
yield rows
self.last_point = None
def _get_blocks(context: Context, panda_mri: str) -> List[Block]:
"""Get panda, seqA, and seqB Blocks using the given context."""
# {part_name: export_name}
panda = context.block_view(panda_mri)
seq_part_names = {}
for source, export in panda.exports.value.rows():
if export in SEQ_TABLES:
assert source.endswith(
".table"
), f"Expected export {export} to come from SEQx.table, got {source}"
seq_part_names[source[: -len(".table")]] = export
assert (
tuple(sorted(seq_part_names.values())) == SEQ_TABLES
), f"Expected exported attributes {SEQ_TABLES}, got {panda.exports.value.export}"
# {export_name: mri}
seq_mris = {}
for name, mri, _, _, _ in panda.layout.value.rows():
if name in seq_part_names:
export = seq_part_names[name]
seq_mris[export] = mri
assert sorted(seq_mris) == sorted(
seq_part_names.values()
), f"Couldn't find MRI for some of {seq_part_names.values()}"
blocks = [panda]
blocks += [context.block_view(seq_mris[x]) for x in SEQ_TABLES]
return blocks
def doing_pcomp(row_trigger_value: str) -> bool:
"""Indicate whether the row_trigger is for position compare."""
return row_trigger_value == "Position Compare"
[docs]class PandASeqTriggerPart(builtin.parts.ChildPart):
"""Part for operating a pair of Sequencer (SEQ) blocks in a PandA to do position
compare at the start of each row, and time based pulses within the row.
Needs the following exports:
- seqTableA: table Attribute of the first SEQ block
- seqTableB: table Attribute of the second SEQ block
- seqSetEnable: forceSet Method of an SRGATE that is used to gate both SEQs
- seqReset: forceRst Method of an SRGATE that is used to gate both SEQs
"""
def __init__(
self, name: APartName, mri: AMri, initial_visibility: AInitialVisibility = True
) -> None:
super().__init__(
name, mri, initial_visibility=initial_visibility, stateful=False
)
# Stored generator for positions
self.generator: scanning.hooks.AGenerator = None
# The last index we have loaded
self.loaded_up_to: int = 0
# The last scan point index of the current run
self.scan_up_to: int = 0
# If we are currently loading then block loading more points
self.loading: bool = False
# What is the mapping of scannable name to MotorInfo
self.axis_mapping: Dict[str, pmac.infos.MotorInfo] = {}
# The minimum turnaround time for non-joined points
self.min_turnaround: Optional[MinTurnaround] = None
# {(scannable, increasing): trigger_enum}
self.trigger_enums: Dict[Tuple[str, bool], str] = {}
# The panda Block we will be prodding
self.panda: Optional[Any] = None
# The DoubleBuffer object used to load tables during a scan
self.db_seq_table: Optional[DoubleBuffer] = None
def setup(self, registrar: PartRegistrar) -> None:
super().setup(registrar)
# Hooks
registrar.hook(scanning.hooks.ReportStatusHook, self.on_report_status)
registrar.hook(
(
scanning.hooks.ConfigureHook,
scanning.hooks.SeekHook,
),
self.on_configure,
)
registrar.hook(scanning.hooks.PreRunHook, self.on_pre_run)
registrar.hook(scanning.hooks.RunHook, self.on_run)
registrar.hook(builtin.hooks.ResetHook, self.on_reset)
registrar.hook(
(scanning.hooks.AbortHook, scanning.hooks.PauseHook), self.on_abort
)
registrar.hook(scanning.hooks.PostRunArmedHook, self.post_inner_scan)
[docs] @add_call_types
def on_report_status(
self, context: scanning.hooks.AContext
) -> scanning.hooks.UInfos:
child = context.block_view(self.mri)
# Work out if we need the motor controller to send start of row triggers
# or no triggers
if doing_pcomp(child.rowTrigger.value):
# Doing position compare, don't need any triggers
trigger = scanning.infos.MotionTrigger.NONE
else:
# Waiting for bit at the start of each row, so need this signal
trigger = scanning.infos.MotionTrigger.ROW_GATE
info = scanning.infos.MotionTriggerInfo(trigger)
return info
[docs] def setup_pcomp_dicts(
self, seqa: Block, seqb: Block, axis_mapping: Dict[str, pmac.infos.MotorInfo]
) -> None:
"""Setup the axis mapping and trigger enum attributes for position compare."""
# Check that both sequencers are pointing to the same encoders
seq_pos = {}
for suff in "abc":
# Something like INENC1.VAL or ZERO
seqa_pos_inp = seqa["pos" + suff].value
seqb_pos_inp = seqb["pos" + suff].value
assert (
seqa_pos_inp == seqb_pos_inp
), f"SeqA Pos{suff} = {seqa_pos_inp} != SeqB Pos{suff} = {seqb_pos_inp}"
seq_pos[seqa_pos_inp] = f"POS{suff.upper()}"
# Fix the mres and offsets from the panda positions table
assert self.panda, "No PandA"
positions_table = self.panda.positions.value
for i, name in enumerate(positions_table.name):
try:
pos = seq_pos[name]
except KeyError:
# This is a position not connected to the seq, this is fine
pass
else:
# This is a position that we can compare on, check its
# dataset name which is the scannable name
scannable = positions_table.datasetName[i]
info = axis_mapping.get(scannable, None)
if info:
# We are asked to scan this, so correct its resolution
# and store
info.resolution = positions_table.scale[i]
info.offset = positions_table.offset[i]
self.axis_mapping[scannable] = info
self.trigger_enums[(scannable, True)] = f"{pos}>=POSITION"
self.trigger_enums[(scannable, False)] = f"{pos}<=POSITION"
# Check we have at least one entry
assert self.axis_mapping, (
"None of the seq inputs %s can be mapped to scannable names "
"in %s. Did you define datasetName entries for these rows in "
"the PandA positions table?" % (sorted(seq_pos), sorted(axis_mapping))
)
[docs] def reset_seq(self, context):
"""Reset the PandA sequencer using the given context.
We need to use the correct context when calling this function, as it will
otherwise get blocked.
"""
panda = context.block_view(self.panda_mri)
panda.seqReset()
# Allow CamelCase as these parameters will be serialized
# noinspection PyPep8Naming
[docs] @add_call_types
def on_pre_run(self, context: scanning.hooks.AContext) -> None:
assert self.panda, "No PandA"
assert self.db_seq_table, "No DoubleBuffer"
# If there are MotorInfo's enable seqTableA here so ready to receive the
# first signal (once the motors have moved to the first point).
if self.axis_mapping:
self.panda.seqSetEnable()
[docs] @add_call_types
def on_run(self, context: scanning.hooks.AContext) -> None:
# When there are no MotorInfo's the first row will have Trigger.IMMEDIATE
# so don't enable seqTableA until running.
if not self.axis_mapping:
self.panda.seqSetEnable()
futures = self.db_seq_table.run()
context.wait_all_futures(futures)
[docs] @add_call_types
def on_reset(self, context: builtin.hooks.AContext) -> None:
super().on_reset(context)
self.on_abort(context)
[docs] @add_call_types
def on_abort(self, context: builtin.hooks.AContext) -> None:
try:
self.reset_seq(context)
except (AttributeError, KeyError):
# Ensure we can reset or abort if different design is used for PandA
pass
if self.db_seq_table is not None:
self.db_seq_table.clean_up()
[docs] @add_call_types
def post_inner_scan(
self,
context: scanning.hooks.AContext,
completed_steps: scanning.hooks.ACompletedSteps,
steps_to_do: scanning.hooks.AStepsToDo,
part_info: scanning.hooks.APartInfo,
generator: scanning.hooks.AGenerator,
axesToMove: scanning.hooks.AAxesToMove,
) -> None:
self.on_abort(context)
self.on_configure(
context, completed_steps, steps_to_do, part_info, generator, axesToMove
)