Source code for httomo_backends.methods_database.query

from dataclasses import dataclass
from enum import Enum
from importlib import import_module
from types import ModuleType
from typing import Callable, List, Literal, Optional, Tuple
from pathlib import Path
import numpy as np

import yaml

YAML_DIR = Path(__file__).parent / "packages/"


class Pattern(Enum):
    """
    Slicing orientation ("pattern") of a tomographic dataset.

    The member names are the same strings that
    `MethodsDatabaseQuery.get_pattern` accepts from a method's "pattern"
    entry.
    """

    # Integer values are fixed; only the names are matched against the
    # strings read from the methods database.
    projection = 0
    sinogram = 1
    all = 2
@dataclass(frozen=True)
class GpuMemoryRequirement:
    """
    Immutable record of the "memory_gpu" entry of a method.

    Instances are built by `MethodsDatabaseQuery.get_memory_gpu_params` from
    the "multiplier" and "method" keys of that entry.
    """

    # Value of the "multiplier" key; may be None.
    multiplier: Optional[float] = 1.0
    # Value of the "method" key: either "direct" or "module".
    method: Literal["direct", "module"] = "direct"
class MethodsDatabaseQuery:
    """
    Implements the `MethodQuery` protocol from `httomo`.

    An instance is bound to a single method, identified by its module path
    and function name. Static facts about the method are read from the
    package's YAML file under `YAML_DIR`; computed quantities are delegated
    to helper functions looked up in the package's `supporting_funcs`
    modules.
    """

    def __init__(self, module_path: str, method_name: str):
        """
        Parameters
        ----------
        module_path : str
            The full module path of the method, including the top-level
            package name. Ie, `httomolib.misc.images`.
        method_name : str
            The name of the method function.
        """
        self.module_path = module_path
        self.method_name = method_name

    def _get_method_info(self, attr: str):
        """Get the information about this method associated with `attr` that
        is stored in the relevant YAML file below `YAML_DIR`.

        The YAML file is selected by the top-level package name of
        `self.module_path`: `<YAML_DIR>/<package>.yaml` for `httomo` itself
        and `<YAML_DIR>/backends/<package>/<package>.yaml` for any other
        package. The remaining components of the module path, followed by
        the method name, are used as nested keys into the loaded document.

        Parameters
        ----------
        attr : str
            The name of the piece of information about the method being
            requested (for example, "pattern").

        Returns
        -------
        The requested piece of information about the method.

        Raises
        ------
        FileNotFoundError
            If the YAML file for the package doesn't exist.
        KeyError
            If a component of the method path, or `attr` itself, cannot be
            looked up in the YAML document.
        """
        method_path = f"{self.module_path}.{self.method_name}"
        split_method_path = method_path.split(".")
        package_name = split_method_path[0]

        # open the library file for the package
        ext_package_path = (
            "" if package_name == "httomo" else f"backends/{package_name}/"
        )
        yaml_info_path = Path(YAML_DIR, ext_package_path, f"{package_name}.yaml")
        if not yaml_info_path.exists():
            err_str = f"The YAML file {yaml_info_path} doesn't exist."
            raise FileNotFoundError(err_str)

        with open(yaml_info_path, "r", encoding="utf-8") as f:
            info = yaml.safe_load(f)

        # TypeError is caught alongside KeyError so that walking into a
        # non-mapping node (empty document, scalar leaf) is reported with the
        # same KeyError as a missing key rather than a bare TypeError.
        for key in split_method_path[1:]:
            try:
                info = info[key]
            except (KeyError, TypeError) as err:
                raise KeyError(
                    f"The key {key} is not present ({method_path})"
                ) from err

        try:
            return info[attr]
        except (KeyError, TypeError) as err:
            raise KeyError(
                f"The attribute {attr} is not present on {method_path}"
            ) from err

    def get_pattern(self) -> Pattern:
        """Return the `Pattern` named by the method's "pattern" entry.

        Raises
        ------
        AssertionError
            If the entry is not one of "projection", "sinogram" or "all".
        """
        p = self._get_method_info("pattern")
        # Raised explicitly rather than via `assert` so the validation is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if p not in ("projection", "sinogram", "all"):
            raise AssertionError(
                f"The pattern {p} that is listed for the method "
                f"{self.module_path}.{self.method_name} is invalid."
            )
        return Pattern[p]

    def get_output_dims_change(self) -> bool:
        """Return the truthiness of the method's "output_dims_change" entry."""
        p = self._get_method_info("output_dims_change")
        return bool(p)

    def get_implementation(self) -> Literal["cpu", "gpu", "gpu_cupy"]:
        """Return the method's "implementation" entry.

        Raises
        ------
        AssertionError
            If the entry is not one of "gpu", "gpu_cupy" or "cpu".
        """
        p = self._get_method_info("implementation")
        # Raised explicitly rather than via `assert` so the validation is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if p not in ("gpu", "gpu_cupy", "cpu"):
            raise AssertionError(
                f"The implementation arch {p} listed for method {self.module_path}.{self.method_name} is invalid"
            )
        return p

    def save_result_default(self) -> bool:
        """Return the method's "save_result_default" entry as stored."""
        return self._get_method_info("save_result_default")

    def padding(self) -> bool:
        """Return the method's "padding" entry as stored."""
        return self._get_method_info("padding")

    def get_memory_gpu_params(
        self,
    ) -> Optional[GpuMemoryRequirement]:
        """Return the method's "memory_gpu" entry as a `GpuMemoryRequirement`.

        The entry may be a single mapping or a list of mappings; a list is
        merged into one mapping, with later items overriding earlier ones.

        Returns
        -------
        Optional[GpuMemoryRequirement]
            `None` when the entry is null or the literal string "None",
            otherwise a requirement built from the "multiplier" and "method"
            keys.

        Raises
        ------
        KeyError
            If the entry lacks a "multiplier" or "method" key.
        """
        p = self._get_method_info("memory_gpu")
        if p is None or p == "None":
            return None
        if isinstance(p, list):
            # convert to dict first
            d: dict = {}
            for item in p:
                d |= item
        else:
            d = p
        # NOTE(review): values are passed through as loaded; a "multiplier"
        # written as `None` in the YAML would presumably arrive here as the
        # string "None" rather than a real None - confirm against callers.
        return GpuMemoryRequirement(multiplier=d["multiplier"], method=d["method"])

    def calculate_memory_bytes(
        self, non_slice_dims_shape: Tuple[int, int], dtype: np.dtype, **kwargs
    ) -> Tuple[int, int]:
        """Delegate to `_calc_memory_bytes_<method_name>` in the supporting
        functions module and return its result.

        Raises
        ------
        AttributeError
            If the supporting module does not define that function.
        """
        smodule = self._import_supporting_funcs_module()
        module_mem: Callable = getattr(
            smodule, "_calc_memory_bytes_" + self.method_name
        )
        memory_bytes: Tuple[int, int] = module_mem(
            non_slice_dims_shape, dtype, **kwargs
        )
        return memory_bytes

    def calculate_output_dims(
        self, non_slice_dims_shape: Tuple[int, int], **kwargs
    ) -> Tuple[int, int]:
        """Delegate to `_calc_output_dim_<method_name>` in the supporting
        functions module and return its result.

        Raises
        ------
        AttributeError
            If the supporting module does not define that function.
        """
        smodule = self._import_supporting_funcs_module()
        module_dims: Callable = getattr(
            smodule, "_calc_output_dim_" + self.method_name
        )
        return module_dims(non_slice_dims_shape, **kwargs)

    def calculate_padding(self, **kwargs) -> Tuple[int, int]:
        """Delegate to `_calc_padding_<method_name>` in the supporting
        functions module and return its result.

        Raises
        ------
        AttributeError
            If the supporting module does not define that function.
        """
        smodule = self._import_supporting_funcs_module()
        module_pad: Callable = getattr(smodule, "_calc_padding_" + self.method_name)
        return module_pad(**kwargs)

    def _import_supporting_funcs_module(self) -> ModuleType:
        """Import the supporting functions module that mirrors
        `self.module_path`.

        The module name is formed by inserting `supporting_funcs` after the
        top-level package name of `self.module_path` and prefixing the result
        with `httomo_backends.methods_database.packages.backends.`.
        """
        package_name, *submodules = self.module_path.split(".")
        relative_path = ".".join([package_name, "supporting_funcs", *submodules])
        return import_module(
            "httomo_backends.methods_database.packages.backends." + relative_path
        )

    def swap_dims_on_output(self) -> bool:
        """Return True when `self.module_path` starts with "tomopy.recon"."""
        return self.module_path.startswith("tomopy.recon")
class MethodDatabaseRepository:
    """
    Implements the `MethodRepository` protocol from `httomo`.

    Acts as a factory handing out `MethodsDatabaseQuery` objects.
    """

    def query(self, module_path: str, method_name: str) -> MethodsDatabaseQuery:
        """Create the query object for one method.

        Parameters
        ----------
        module_path : str
            The full module path of the method, including the top-level
            package name.
        method_name : str
            The name of the method function.

        Returns
        -------
        MethodsDatabaseQuery
            A new query bound to the given module path and method name.
        """
        method_query = MethodsDatabaseQuery(
            module_path=module_path, method_name=method_name
        )
        return method_query