import os
from typing import Dict, Iterator, List, Optional
import h5py
from annotypes import Anno, add_call_types
from scanpointgenerator import CompoundGenerator
from vdsgen import InterleaveVDSGenerator, ReshapeVDSGenerator
from malcolm.core import APartName, BadValueError, Future, Info, PartRegistrar
from malcolm.modules import builtin, scanning
PartInfo = Dict[str, List[Info]]
# If the HDF writer doesn't get new frames in this time (seconds), consider it
# stalled and raise
FRAME_TIMEOUT = 60
# Pull re-used annotypes into our namespace in case we are subclassed
APartName = APartName
AMri = builtin.parts.AMri
AInitialVisibility = builtin.parts.AInitialVisibility
with Anno("name of sum dataset"):
ASumName = str
with Anno("name of secondary dataset (e.g. sum)"):
ASecondaryDataset = str
with Anno("number of FR/FP pairs"):
ANumberPairs = int
with Anno("number of mca datasets to be written in total"):
ANumberDataset = int
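# Letter suffixes used to name the file written by each data writer process,
# e.g. "<raw_name>_A_000000.h5" for the first writer (see create_vds below)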
dict_filewriter = ["A", "B", "C", "D", "E", "F", "G", "H"]
def greater_than_zero(v: int) -> bool:
return v > 0
def create_dataset_infos(
name: str, generator: CompoundGenerator, filename: str, secondary_set: str
) -> Iterator[Info]:
    # Update the dataset table
generator_rank = len(generator.dimensions)
# Add the primary datasource
yield scanning.infos.DatasetProducedInfo(
name=f"{name}.data",
filename=filename,
type=scanning.infos.DatasetType.PRIMARY,
rank=generator_rank + 2,
path="/entry/xspress/data",
uniqueid="/entry/xspress/uid",
)
# Add any setpoint dimensions
for dim in generator.axes:
yield scanning.infos.DatasetProducedInfo(
name=f"{dim}.value_set",
filename=filename,
type=scanning.infos.DatasetType.POSITION_SET,
rank=1,
path=f"/entry/xspress/{dim}_set",
uniqueid="",
)
def create_raw_dataset_infos(
name: str, rank: int, filename: str, num_datasets: int, pairs: int
) -> Iterator[Info]:
for i in range(num_datasets):
yield scanning.infos.DatasetProducedInfo(
name=f"{name}.raw_mca{i}",
filename=filename,
type=scanning.infos.DatasetType.RAW,
rank=rank,
path="/raw/mca" + str(i),
uniqueid="",
)
for i in range(9):
yield scanning.infos.DatasetProducedInfo(
name=f"{name}.scalar_{i}",
filename=filename,
type=scanning.infos.DatasetType.RAW,
rank=rank,
path="/raw/scalar_" + str(i),
uniqueid="",
)
for i in range(pairs):
yield scanning.infos.DatasetProducedInfo(
name=f"{name}.uid{i+1}",
filename=filename,
type=scanning.infos.DatasetType.RAW,
rank=rank,
path="/raw/uid" + str(i),
uniqueid="",
)
yield scanning.infos.DatasetProducedInfo(
name=f"{name}.dtc",
filename=filename,
type=scanning.infos.DatasetType.RAW,
rank=rank,
path="/raw/dtc",
uniqueid="",
)
def files_shape(frames, block_size, file_count):
# all files get at least per_file blocks
    per_file = int(frames) // int(file_count * block_size)
# this is the remainder once per_file blocks have been distributed
remainder = int(frames) % int(file_count * block_size)
# distribute the remainder
remainders = [
block_size if remains > block_size else remains
for remains in range(remainder, 0, -block_size)
]
# pad the remainders list with zeros
remainders += [0] * (file_count - len(remainders))
shape = tuple(int(per_file * block_size + remainders[i]) for i in range(file_count))
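    # e.g. files_shape(100, 10, 4): per_file=2, remainder=20,
    # remainders=[10, 10, 0, 0], giving shape (30, 30, 20, 20)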
return shape
def one_vds(
vds_folder,
vds_name,
files,
width,
height,
shape,
generator,
alternates,
block_size,
source_node,
target_node,
d_type,
):
    # this VDS interleaves the frames from one file per data writer into a
    # single 1D data set
gen = InterleaveVDSGenerator(
vds_folder,
files=files,
        source={"height": height, "width": width, "dtype": d_type, "shape": shape},
output=vds_name,
source_node=source_node,
target_node="process/" + target_node + "_interleave",
block_size=block_size,
log_level=1,
)
gen.generate_vds()
# this VDS shapes the data to match the dimensions of the scan
gen = ReshapeVDSGenerator(
path=vds_folder,
files=[vds_name],
source_node="process/" + target_node + "_interleave",
target_node=target_node,
output=vds_name,
shape=generator.shape,
alternate=alternates,
log_level=1,
)
gen.generate_vds()
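    # The output VDS now holds two views of the same frames: the interleaved
    # 1D dataset at "process/<target_node>_interleave" and the scan-shaped
    # dataset at "<target_node>"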
def create_vds(generator, raw_name, vds_path, child, num_datasets):
vds_folder, vds_name = os.path.split(vds_path)
image_width = int(child.imageWidth.value)
image_height = int(child.imageHeight.value)
block_size = int(child.blockSize.value)
hdf_count = int(child.numProcesses.value)
data_type = str(child.dataType.value)
# hdf_shape tuple represents the number of images in each file
hdf_shape = files_shape(generator.size, block_size, hdf_count)
    # Alternating in the first (slowest) dimension has no effect here. If any
    # subsequent dimensions alternate then it will radically slow down VDS
    # creation and reading. We rely on a scanning.parts.UnrollingPart to
    # unroll such scans into one long line instead.
if any(dim.alternate for dim in generator.dimensions[1:]):
raise BadValueError(
"Snake scans are not supported as the VDS is not performant. You "
"can add a scanning.parts.UnrollingPart to the top level scan "
"block to unroll the scan into one long line"
)
alternates = None
files = [
os.path.join(vds_folder, f"{raw_name}_{dict_filewriter[i]}_{0:06d}.h5")
for i in range(hdf_count)
]
metafile = os.path.join(vds_folder, f"{raw_name}_meta.h5")
shape = (hdf_shape, image_height, image_width)
# prepare a vds for the image data
one_vds(
vds_folder,
vds_name,
files,
image_width,
image_height,
shape,
generator,
alternates,
block_size,
"data",
"data",
data_type.lower(),
)
chan_per_file = int(num_datasets / hdf_count)
file_counter = 0
with h5py.File(vds_path, "r+", libver="latest") as vds:
        # MCA will have one dataset per channel shaped like [num_frames, 1]
for chan in range(num_datasets):
if chan >= (chan_per_file * (file_counter + 1)):
file_counter += 1
vds["raw/mca" + str(chan)] = h5py.ExternalLink(
files[file_counter], f"/mca_{chan}"
)
# UIDs will have one dataset per writer shaped like [num_frames]
file_counter = 0
for f in range(hdf_count):
vds["raw/uid" + str(f)] = h5py.ExternalLink(files[f], "/uid")
# Scalars will have one dataset per scalar (nine in total) with
# shape [num_frames, num_channels]
for scalar in range(9):
vds["raw/scalar_" + str(scalar)] = h5py.ExternalLink(
metafile, f"/scalar_{scalar}"
)
# DTC has only one dataset shaped like [num_frames, num_channels]
vds["raw/dtc"] = h5py.ExternalLink(metafile, "/dtc")
set_bases = ["/entry/xspress/"]
set_data = ["/data"]
def add_nexus_nodes(generator, vds_file_path):
"""Add in the additional information to make this into a standard nexus
format file:-
(a) create the standard structure under the 'entry' group with a
subgroup for each dataset. 'set_bases' lists the data sets we make here.
(b) save a dataset for each axis in each of the dimensions of the scan
representing the demand position at every point in the scan.
"""
    # create the "axes" attribute: a list naming the demand dataset for each
    # scan dimension, padded with "." for multi-axis dimensions and for the
    # two detector dimensions
pad_dims = []
for d in generator.dimensions:
if len(d.axes) == 1:
pad_dims.append(f"{d.axes[0]}_set")
else:
pad_dims.append(".")
pad_dims += ["."] * 2 # assume a 2 dimensional detector
with h5py.File(vds_file_path, "r+", libver="latest") as vds:
vds.swmr_mode = True
for data, node in zip(set_data, set_bases):
# create a group for this entry
vds.require_group(node)
# points to the axis demand data sets
vds[node].attrs["axes"] = pad_dims
vds[node].attrs["NX_class"] = ["NXdata"]
# points to the detector dataset for this entry
vds[node].attrs["signal"] = data.split("/")[-1]
# a hard link from this entry 'signal' to the actual data
vds[node + data] = vds[data]
axis_sets = {}
# iterate the axes in each dimension of the generator to create the
# axis information nodes
for i, d in enumerate(generator.dimensions):
for axis in d.axes:
# add signal data dimension for axis
axis_indices = f"{axis}_set_indices"
vds[node].attrs[axis_indices] = i
# demand positions for axis
axis_set = f"{axis}_set"
if axis_sets.get(axis_set):
# link to the first entry's demand list
vds[node + axis_set] = axis_sets[axis_set]
else:
# create the demand list for the first entry only
axis_demands = d.get_positions(axis)
vds.create_dataset(node + axis_set, data=axis_demands)
vds[node + axis_set].attrs["units"] = generator.units[axis]
axis_sets[axis_set] = vds[node + axis_set]
vds["entry"].attrs["NX_class"] = ["NXentry"]
# We will set these attributes on the child block, so don't save them
@builtin.util.no_save("fileName", "filePath", "numCapture", "numFramesChunks")
class XspressWriterPart(builtin.parts.ChildPart):
"""Part for controlling an `xspress3_writer_block` in a Device"""
    # Future for the start action
    start_future: Optional[Future] = None
    unique_id_offset: int = 0
    # The HDF5 layout file we write to say where the datasets go
    layout_filename: str = ""
    exposure_time: float = 0.0
    # Future waiting for the first frame to be captured
    array_future: Optional[Future] = None
    # The uniqueId the writer should reach when the current run is complete
    done_when_reaches: int = 0
def __init__(
self,
name: APartName,
mri: AMri,
initial_visibility: AInitialVisibility = True,
sum_name: ASumName = "sum",
secondary_set: ASecondaryDataset = "sum",
num_pairs: ANumberPairs = 1,
num_datasets: ANumberDataset = 1,
) -> None:
self.sum_name = sum_name
self.secondary_set = secondary_set
self.num_pairs = num_pairs
self.num_datasets = num_datasets
super().__init__(name, mri, initial_visibility)
    @add_call_types
def on_reset(self, context: scanning.hooks.AContext) -> None:
super().on_reset(context)
self.on_abort(context)
def setup(self, registrar: PartRegistrar) -> None:
super().setup(registrar)
# Tell the controller to expose some extra configure parameters
registrar.report(scanning.hooks.ConfigureHook.create_info(self.on_configure))
# Hooks
registrar.hook(scanning.hooks.ConfigureHook, self.on_configure)
registrar.hook(scanning.hooks.SeekHook, self.on_seek)
registrar.hook(scanning.hooks.RunHook, self.on_run)
registrar.hook(scanning.hooks.PostRunArmedHook, self.on_post_run_armed)
registrar.hook(scanning.hooks.PostRunReadyHook, self.on_post_run_ready)
registrar.hook(scanning.hooks.AbortHook, self.on_abort)
registrar.hook(scanning.hooks.PauseHook, self.on_pause)
    @add_call_types
def on_pause(self, context: scanning.hooks.AContext) -> None:
raise NotImplementedError("Pause not implemented")
# Allow CamelCase as these parameters will be serialized
# noinspection PyPep8Naming
    @add_call_types
def on_seek(
self,
context: scanning.hooks.AContext,
completed_steps: scanning.hooks.ACompletedSteps,
steps_to_do: scanning.hooks.AStepsToDo,
) -> None:
# This is rewinding or setting up for another batch, so the detector
# will skip to a uniqueID that has not been produced yet
self.unique_id_offset = completed_steps - self.done_when_reaches
self.done_when_reaches += steps_to_do
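        # e.g. if done_when_reaches was 100 and we seek back to step 80 with
        # 20 steps to do, the offset becomes -20 and done_when_reaches 120, so
        # captured frame 101 is reported as completed step 81 (see
        # update_completed_steps below)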
child = context.block_view(self.mri)
        # Just reset the array counter
child.arrayCounter.put_value(0)
# Start a future waiting for the first array
self.array_future = child.when_value_matches_async(
"numCaptured", greater_than_zero
)
    @add_call_types
def on_run(self, context: scanning.hooks.AContext) -> None:
context.wait_all_futures(self.array_future)
context.unsubscribe_all()
child = context.block_view(self.mri)
child.numCaptured.subscribe_value(self.update_completed_steps)
child.when_value_matches(
"numCaptured",
self.done_when_reaches,
event_timeout=self.exposure_time + FRAME_TIMEOUT,
)
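        # event_timeout bounds the gap between successive numCaptured updates,
        # so a stalled writer raises after about one exposure plus
        # FRAME_TIMEOUT rather than timing out the run as a whole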
    @add_call_types
def on_post_run_armed(
self,
context: scanning.hooks.AContext,
steps_to_do: scanning.hooks.AStepsToDo,
) -> None:
self.done_when_reaches += self.num_pairs * steps_to_do
    @add_call_types
def on_post_run_ready(self, context: scanning.hooks.AContext) -> None:
# If this is the last one, wait until the file is closed
context.wait_all_futures(self.start_future)
    @add_call_types
def on_abort(self, context: scanning.hooks.AContext) -> None:
child = context.block_view(self.mri)
child.stop()
def update_completed_steps(self, value: int) -> None:
completed_steps = value + self.unique_id_offset
assert self.registrar, "No registrar"
self.registrar.report(scanning.infos.RunProgressInfo(completed_steps))