import asyncio
import json
import logging
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
import requests
import sseclient
from pydantic import BaseModel
from smartem_backend.model.http_request import (
AcquisitionCreateRequest,
AgentInstructionAcknowledgement,
AtlasCreateRequest,
AtlasTileCreateRequest,
FoilHoleCreateRequest,
GridCreateRequest,
GridSquareCreateRequest,
GridSquarePositionRequest,
MicrographCreateRequest,
)
from smartem_backend.model.http_response import (
AcquisitionResponse,
AgentInstructionAcknowledgementResponse,
AtlasResponse,
AtlasTileGridSquarePositionResponse,
AtlasTileResponse,
FoilHoleResponse,
GridResponse,
GridSquareResponse,
MicrographResponse,
)
from smartem_common.entity_status import AcquisitionStatus, GridSquareStatus, GridStatus
from smartem_common.schemas import (
AcquisitionData,
AtlasData,
AtlasTileData,
AtlasTileGridSquarePositionData,
FoilHoleData,
GridData,
GridSquareData,
MicrographData,
)
# TODO: find a way to remove the boilerplate needed to convert each EntityData type into its
# EntityCreateRequest counterpart, if at all possible.
class EntityConverter:
    """
    Handles conversions between EPU data model and API request/response models.

    Separating this conversion logic keeps the main client code cleaner. Every
    method is a stateless mapping from one ``*Data`` entity onto the matching
    request model; nothing is cached and no I/O is performed here.
    """

    @staticmethod
    def _optional_str(value) -> str | None:
        """Return ``str(value)`` for a truthy value, otherwise None."""
        return str(value) if value else None

    @staticmethod
    def _atlas_tile_request(tile: AtlasTileData, atlas_uuid) -> AtlasTileCreateRequest:
        """Build a tile request from ``tile``, attaching it to ``atlas_uuid``.

        Shared by ``atlas_to_request`` (which supplies the parent atlas uuid) and
        ``atlas_tile_to_request`` (which supplies the tile's own ``atlas_uuid``).
        """
        placement = tile.tile_position
        return AtlasTileCreateRequest(
            atlas_uuid=atlas_uuid,
            uuid=tile.uuid,
            tile_id=tile.id,
            position_x=placement.position[0],
            position_y=placement.position[1],
            size_x=placement.size[0],
            size_y=placement.size[1],
            file_format=tile.file_format,
            base_filename=tile.base_filename,
        )

    @staticmethod
    def acquisition_to_request(entity: AcquisitionData) -> AcquisitionCreateRequest:
        """Convert EPU session data to acquisition request model.

        The instrument fields are None when ``entity.instrument`` is falsy, and the
        status is always set to ``AcquisitionStatus.STARTED``.
        """
        instrument = entity.instrument
        return AcquisitionCreateRequest(
            uuid=entity.uuid,
            # TODO check if natural `id` should also be included
            name=entity.name,
            start_time=entity.start_time,
            storage_path=entity.storage_path,
            atlas_path=entity.atlas_path,
            clustering_mode=entity.clustering_mode,
            clustering_radius=entity.clustering_radius,
            instrument_model=instrument.instrument_model if instrument else None,
            instrument_id=instrument.instrument_id if instrument else None,
            computer_name=instrument.computer_name if instrument else None,
            status=AcquisitionStatus.STARTED,
        )

    @staticmethod
    def grid_to_request(entity: GridData, lowmag: bool = False) -> GridCreateRequest:
        """Convert Grid data to grid request model.

        The grid name is taken from the parent acquisition ("Unknown" when there is
        none) and the status is always ``GridStatus.NONE``.
        """
        acquisition = entity.acquisition_data
        return GridCreateRequest(
            uuid=entity.uuid,
            status=GridStatus.NONE,
            name=acquisition.name if acquisition else "Unknown",
            # NOTE(review): unlike `name`, this is not guarded, so a grid without
            # acquisition_data raises AttributeError here. Left unchanged because the
            # correct fallback for a grid with no parent acquisition is not known.
            acquisition_uuid=entity.acquisition_data.uuid,
            data_dir=EntityConverter._optional_str(entity.data_dir),
            atlas_dir=EntityConverter._optional_str(entity.atlas_dir),
            lowmag=lowmag,
        )

    @staticmethod
    def gridsquare_to_request(entity: GridSquareData, lowmag: bool = False) -> GridSquareCreateRequest:
        """Convert GridSquareData to grid square request model.

        Fields sourced from ``entity.metadata`` or ``entity.manifest`` are None when
        the respective object is falsy. The status is always ``GridSquareStatus.NONE``.
        """
        metadata = entity.metadata
        manifest = entity.manifest
        stage = metadata.stage_position if metadata else None
        return GridSquareCreateRequest(
            grid_uuid=entity.grid_uuid,
            gridsquare_id=entity.gridsquare_id,
            uuid=entity.uuid,
            center_x=entity.center_x,
            center_y=entity.center_y,
            size_width=entity.size_width,
            size_height=entity.size_height,
            status=GridSquareStatus.NONE,
            data_dir=EntityConverter._optional_str(entity.data_dir),
            # Metadata-derived fields
            atlas_node_id=metadata.atlas_node_id if metadata else None,
            state=metadata.state if metadata else None,
            rotation=metadata.rotation if metadata else None,
            image_path=EntityConverter._optional_str(metadata.image_path) if metadata else None,
            selected=metadata.selected if metadata else None,
            unusable=metadata.unusable if metadata else None,
            stage_position_x=stage.x if stage else None,
            stage_position_y=stage.y if stage else None,
            stage_position_z=stage.z if stage else None,
            # Manifest-derived fields
            acquisition_datetime=manifest.acquisition_datetime if manifest else None,
            defocus=manifest.defocus if manifest else None,
            magnification=manifest.magnification if manifest else None,
            pixel_size=manifest.pixel_size if manifest else None,
            detector_name=manifest.detector_name if manifest else None,
            applied_defocus=manifest.applied_defocus if manifest else None,
            lowmag=lowmag,
        )

    @staticmethod
    def foilhole_to_request(entity: FoilHoleData) -> FoilHoleCreateRequest:
        """Convert FoilHoleData to foil hole request model.

        The entity's ``id`` is sent as ``foilhole_id``; all other fields map by name.
        """
        return FoilHoleCreateRequest(
            uuid=entity.uuid,
            foilhole_id=entity.id,
            gridsquare_id=entity.gridsquare_id,
            gridsquare_uuid=entity.gridsquare_uuid,
            center_x=entity.center_x,
            center_y=entity.center_y,
            quality=entity.quality,
            rotation=entity.rotation,
            size_width=entity.size_width,
            size_height=entity.size_height,
            x_location=entity.x_location,
            y_location=entity.y_location,
            x_stage_position=entity.x_stage_position,
            y_stage_position=entity.y_stage_position,
            diameter=entity.diameter,
            is_near_grid_bar=entity.is_near_grid_bar,
        )

    @staticmethod
    def micrograph_to_request(entity: MicrographData) -> MicrographCreateRequest:
        """Convert MicrographData to micrograph request model.

        Fields sourced from ``entity.manifest`` are None when the manifest is falsy.
        """
        manifest = entity.manifest
        return MicrographCreateRequest(
            uuid=entity.uuid,
            foilhole_uuid=entity.foilhole_uuid,
            foilhole_id=entity.foilhole_id,
            location_id=entity.location_id,
            high_res_path=EntityConverter._optional_str(entity.high_res_path),
            manifest_file=EntityConverter._optional_str(entity.manifest_file),
            acquisition_datetime=manifest.acquisition_datetime if manifest else None,
            defocus=manifest.defocus if manifest else None,
            detector_name=manifest.detector_name if manifest else None,
            energy_filter=manifest.energy_filter if manifest else None,
            phase_plate=manifest.phase_plate if manifest else None,
            image_size_x=manifest.image_size_x if manifest else None,
            image_size_y=manifest.image_size_y if manifest else None,
            binning_x=manifest.binning_x if manifest else None,
            binning_y=manifest.binning_y if manifest else None,
        )

    @staticmethod
    def atlas_to_request(entity: AtlasData) -> AtlasCreateRequest:
        """Convert AtlasData to atlas request model.

        Every tile in ``entity.tiles`` is converted as well and is attached to the
        atlas through ``entity.uuid``.
        """
        return AtlasCreateRequest(
            uuid=entity.uuid,
            atlas_id=entity.id,
            grid_uuid=entity.grid_uuid,
            name=entity.name,
            storage_folder=entity.storage_folder,
            acquisition_date=entity.acquisition_date,
            tiles=[EntityConverter._atlas_tile_request(tile, entity.uuid) for tile in entity.tiles],
        )

    @staticmethod
    def atlas_tile_to_request(entity: AtlasTileData) -> AtlasTileCreateRequest:
        """Convert AtlasTileData to atlas tile request model."""
        return EntityConverter._atlas_tile_request(entity, entity.atlas_uuid)

    @staticmethod
    def gridsquare_position_to_request(entity: AtlasTileGridSquarePositionData) -> GridSquarePositionRequest:
        """Convert AtlasTileGridSquarePositionData to grid square position request model.

        ``entity.position`` supplies (center_x, center_y) and ``entity.size`` supplies
        (size_width, size_height), both read by index 0 and 1.
        """
        position = entity.position
        size = entity.size
        return GridSquarePositionRequest(
            center_x=position[0],
            center_y=position[1],
            size_width=size[0],
            size_height=size[1],
            gridsquare_uuid=entity.gridsquare_uuid,
        )
class SmartEMAPIClient:
    """
    SmartEM API client that provides synchronous HTTP interface.

    This client handles all API communication with the SmartEM Core API and
    converts between EPU data models and API request/response models (via
    ``EntityConverter``). A single ``requests.Session`` is reused for every call;
    use the client as a context manager, or call ``close()``, to release it.
    """

    def __init__(self, base_url: str, timeout: float = 10.0, logger=None):
        """
        Initialize the SmartEM API client

        Args:
            base_url: Base URL for the API (a trailing slash is stripped)
            timeout: Request timeout in seconds, applied to every request
            logger: Optional custom logger instance
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()
        # requests ignores a `timeout` attribute on a Session; the value is kept only
        # for backward compatibility. The effective timeout is passed in `_request`.
        self._session.timeout = timeout
        self._logger = logger or logging.getLogger(__name__)

        # Configure logger if it's the default one
        if not logger:
            # The module logger is shared by all instances: attach a handler only once,
            # otherwise every new client would duplicate each log line.
            if not self._logger.handlers:
                handler = logging.StreamHandler()
                handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
                self._logger.addHandler(handler)
            self._logger.setLevel(logging.INFO)

        self._logger.info(f"Initialized SmartEM API client with base URL: {base_url}")

    def close(self) -> None:
        """Close the client connection. Errors while closing are logged, not raised."""
        try:
            self._session.close()
        except Exception as e:
            self._logger.error(f"Error closing session: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    # Generic API request methods

    def _request(
        self,
        method: str,
        endpoint: str,
        request_model: BaseModel | dict | list[BaseModel] | None = None,
        response_cls=None,
    ):
        """
        Make a generic API request

        Args:
            method: HTTP method (get, post, put, delete)
            endpoint: API endpoint path, relative to the base URL
            request_model: Optional request data: a Pydantic model, a list of
                Pydantic models, or a plain dict (datetime values are ISO-formatted)
            response_cls: Optional response class to parse the response

        Returns:
            Parsed response, list of responses, raw decoded JSON when no
            ``response_cls`` is given, or None for delete operations

        Raises:
            requests.HTTPError: If the HTTP request returns an error status code
            requests.RequestException: If there's a network error or timeout
            ValueError: If the response is not valid JSON or fails validation
            Exception: For any other errors
        """
        url = f"{self.base_url}/{endpoint}"
        json_data = None

        # `is not None` rather than truthiness: an empty list/dict is still a payload.
        if request_model is not None:
            if hasattr(request_model, "model_dump"):
                # It's a Pydantic model
                json_data = request_model.model_dump(mode="json", exclude_none=True)
            elif isinstance(request_model, list):
                json_data = [m.model_dump(mode="json", exclude_none=True) for m in request_model]
            else:
                # It's already a dict, but might contain datetime objects
                json_data = {k: v.isoformat() if isinstance(v, datetime) else v for k, v in request_model.items()}
            self._logger.debug(f"Request data for {method} {url}: {json_data}")

        try:
            self._logger.debug(f"Making {method.upper()} request to {url}")
            # The timeout must be given per request; without it the call can block forever.
            response = self._session.request(method, url, json=json_data, timeout=self.timeout)
            response.raise_for_status()
        except requests.HTTPError as e:
            status_code = e.response.status_code
            # Try to extract error details from the response
            try:
                error_detail = e.response.json().get("detail", str(e))
            except Exception:
                error_detail = e.response.text or str(e)
            self._logger.error(f"HTTP {status_code} error for {method.upper()} {url}: {error_detail}")
            raise
        except requests.RequestException as e:
            self._logger.error(f"Request error for {method.upper()} {url}: {e}")
            self._logger.debug(f"Request error details: {traceback.format_exc()}")
            raise
        except Exception as e:
            self._logger.error(f"Unexpected error making request to {url}: {e}")
            self._logger.debug(f"Error details: {traceback.format_exc()}")
            raise

        # For delete operations, return None
        if method.lower() == "delete":
            self._logger.info(f"Successfully deleted resource at {url}")
            return None

        return self._parse_response(response, url, response_cls)

    def _parse_response(self, response, url: str, response_cls=None):
        """Decode the JSON body of ``response`` and optionally validate it with ``response_cls``.

        Raises:
            ValueError: If the body is not valid JSON or fails validation; the
                original exception is chained as the cause.
        """
        try:
            data = response.json()
        except ValueError as e:
            # json.JSONDecodeError (and the decode error raised by requests) subclass ValueError.
            self._logger.error(f"Could not parse JSON response from {url}: {e}")
            self._logger.debug(f"Raw response: {response.text}")
            raise ValueError(f"Invalid JSON response: {str(e)}") from e

        self._logger.debug(f"Response from {url}: {data}")
        if not response_cls:
            return data

        # Parse response if response_cls is provided
        try:
            if isinstance(data, list):
                return [response_cls.model_validate(item) for item in data]
            return response_cls.model_validate(data)
        except Exception as e:
            self._logger.error(f"Error validating response data from {url}: {e}")
            self._logger.debug(f"Response data that failed validation: {data}")
            raise ValueError(f"Invalid response data: {str(e)}") from e

    # Entity-specific methods

    # Status and Health

    def get_status(self) -> dict[str, object]:
        """Get API status information"""
        return self._request("get", "status")

    def get_health(self) -> dict[str, object]:
        """Get API health check information"""
        return self._request("get", "health")

    # Acquisitions

    def get_acquisitions(self) -> list[AcquisitionResponse]:
        """Get all acquisitions"""
        return self._request("get", "acquisitions", response_cls=AcquisitionResponse)

    def create_acquisition(self, acquisition: AcquisitionData) -> AcquisitionResponse:
        """Create a new acquisition"""
        payload = EntityConverter.acquisition_to_request(acquisition)
        return self._request("post", "acquisitions", payload, AcquisitionResponse)

    def get_acquisition(self, acquisition_uuid: str) -> AcquisitionResponse:
        """Get a single acquisition by UUID"""
        return self._request("get", f"acquisitions/{acquisition_uuid}", response_cls=AcquisitionResponse)

    def update_acquisition(self, acquisition: AcquisitionData) -> AcquisitionResponse:
        """Update an acquisition"""
        payload = EntityConverter.acquisition_to_request(acquisition)
        return self._request("put", f"acquisitions/{payload.uuid}", payload, AcquisitionResponse)

    def delete_acquisition(self, acquisition_uuid: str) -> None:
        """Delete an acquisition"""
        return self._request("delete", f"acquisitions/{acquisition_uuid}")

    # Grids

    def get_grids(self) -> list[GridResponse]:
        """Get all grids"""
        return self._request("get", "grids", response_cls=GridResponse)

    def get_grid(self, grid_uuid: str) -> GridResponse:
        """Get a single grid by UUID"""
        return self._request("get", f"grids/{grid_uuid}", response_cls=GridResponse)

    def update_grid(self, grid: GridData) -> GridResponse:
        """Update a grid"""
        payload = EntityConverter.grid_to_request(grid)
        return self._request("put", f"grids/{payload.uuid}", payload, GridResponse)

    def delete_grid(self, grid_uuid: str) -> None:
        """Delete a grid"""
        return self._request("delete", f"grids/{grid_uuid}")

    def get_acquisition_grids(self, acquisition_uuid: str) -> list[GridResponse]:
        """Get all grids for a specific acquisition"""
        return self._request("get", f"acquisitions/{acquisition_uuid}/grids", response_cls=GridResponse)

    def create_acquisition_grid(self, grid: GridData) -> GridResponse:
        """Create a new grid for a specific acquisition"""
        payload = EntityConverter.grid_to_request(grid)
        return self._request("post", f"acquisitions/{payload.acquisition_uuid}/grids", payload, GridResponse)

    def grid_registered(self, grid_uuid: str) -> bool:
        """POST to the grid's ``registered`` endpoint and return the decoded JSON response."""
        return self._request("post", f"grids/{grid_uuid}/registered")

    # Atlas

    def get_atlases(self) -> list[AtlasResponse]:
        """Get all atlases"""
        return self._request("get", "atlases", response_cls=AtlasResponse)

    def get_atlas(self, atlas_uuid: str) -> AtlasResponse:
        """Get a single atlas by UUID"""
        return self._request("get", f"atlases/{atlas_uuid}", response_cls=AtlasResponse)

    def update_atlas(self, atlas: AtlasData) -> AtlasResponse:
        """Update an atlas"""
        payload = EntityConverter.atlas_to_request(atlas)
        return self._request("put", f"atlases/{payload.uuid}", payload, AtlasResponse)

    def delete_atlas(self, atlas_uuid: str) -> None:
        """Delete an atlas"""
        return self._request("delete", f"atlases/{atlas_uuid}")

    def get_grid_atlas(self, grid_uuid: str) -> AtlasResponse:
        """Get the atlas for a specific grid"""
        return self._request("get", f"grids/{grid_uuid}/atlas", response_cls=AtlasResponse)

    def create_grid_atlas(self, atlas: AtlasData) -> AtlasResponse:
        """Create a new atlas for a grid"""
        payload = EntityConverter.atlas_to_request(atlas)
        return self._request("post", f"grids/{payload.grid_uuid}/atlas", payload, AtlasResponse)

    # Atlas Tiles

    def get_atlas_tiles(self) -> list[AtlasTileResponse]:
        """Get all atlas tiles"""
        return self._request("get", "atlas-tiles", response_cls=AtlasTileResponse)

    def get_atlas_tile(self, tile_uuid: str) -> AtlasTileResponse:
        """Get a single atlas tile by UUID"""
        return self._request("get", f"atlas-tiles/{tile_uuid}", response_cls=AtlasTileResponse)

    def update_atlas_tile(self, tile: AtlasTileData) -> AtlasTileResponse:
        """Update an atlas tile"""
        payload = EntityConverter.atlas_tile_to_request(tile)
        return self._request("put", f"atlas-tiles/{payload.uuid}", payload, AtlasTileResponse)

    def delete_atlas_tile(self, tile_uuid: str) -> None:
        """Delete an atlas tile"""
        return self._request("delete", f"atlas-tiles/{tile_uuid}")

    def get_atlas_tiles_by_atlas(self, atlas_uuid: str) -> list[AtlasTileResponse]:
        """Get all tiles for a specific atlas"""
        return self._request("get", f"atlases/{atlas_uuid}/tiles", response_cls=AtlasTileResponse)

    def create_atlas_tile_for_atlas(self, tile: AtlasTileData) -> AtlasTileResponse:
        """Create a new tile for a specific atlas"""
        payload = EntityConverter.atlas_tile_to_request(tile)
        return self._request("post", f"atlases/{payload.atlas_uuid}/tiles", payload, AtlasTileResponse)

    def link_atlas_tile_and_gridsquare(
        self, gridsquare_position: AtlasTileGridSquarePositionData
    ) -> AtlasTileGridSquarePositionResponse:
        """Link a grid square to a tile"""
        endpoint = f"atlas-tiles/{gridsquare_position.tile_uuid}/gridsquares/{gridsquare_position.gridsquare_uuid}"
        payload = EntityConverter.gridsquare_position_to_request(gridsquare_position)
        return self._request("post", endpoint, payload, AtlasTileGridSquarePositionResponse)

    def link_atlas_tile_and_gridsquares(
        self, gridsquare_positions: list[AtlasTileGridSquarePositionData]
    ) -> list[AtlasTileGridSquarePositionResponse]:
        """Link multiple grid squares to a tile

        All positions must refer to the same tile. An empty input returns an empty
        list without contacting the API.

        Raises:
            AssertionError: If the positions refer to more than one tile
        """
        if not gridsquare_positions:
            return []
        tile_uuids = {pos.tile_uuid for pos in gridsquare_positions}
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        if len(tile_uuids) != 1:
            raise AssertionError("All grid square positions must belong to the same atlas tile")
        tile_uuid = gridsquare_positions[0].tile_uuid
        payload = [EntityConverter.gridsquare_position_to_request(pos) for pos in gridsquare_positions]
        return self._request(
            "post",
            f"atlas-tiles/{tile_uuid}/gridsquares",
            payload,
            AtlasTileGridSquarePositionResponse,
        )

    # GridSquares

    def get_gridsquares(self) -> list[GridSquareResponse]:
        """Get all grid squares"""
        return self._request("get", "gridsquares", response_cls=GridSquareResponse)

    def get_gridsquare(self, gridsquare_uuid: str) -> GridSquareResponse:
        """Get a single grid square by UUID"""
        return self._request("get", f"gridsquares/{gridsquare_uuid}", response_cls=GridSquareResponse)

    def update_gridsquare(self, gridsquare: GridSquareData, lowmag: bool = False) -> GridSquareResponse:
        """Update a grid square"""
        payload = EntityConverter.gridsquare_to_request(gridsquare, lowmag=lowmag)
        return self._request("put", f"gridsquares/{gridsquare.uuid}", payload, GridSquareResponse)

    def delete_gridsquare(self, gridsquare_uuid: str) -> None:
        """Delete a grid square"""
        return self._request("delete", f"gridsquares/{gridsquare_uuid}")

    def get_grid_gridsquares(self, grid_uuid: str) -> list[GridSquareResponse]:
        """Get all grid squares for a specific grid"""
        return self._request("get", f"grids/{grid_uuid}/gridsquares", response_cls=GridSquareResponse)

    def create_grid_gridsquare(self, gridsquare: GridSquareData, lowmag: bool = False) -> GridSquareResponse:
        """Create a new grid square for a specific grid"""
        payload = EntityConverter.gridsquare_to_request(gridsquare, lowmag=lowmag)
        return self._request("post", f"grids/{payload.grid_uuid}/gridsquares", payload, GridSquareResponse)

    def gridsquare_registered(self, gridsquare_uuid: str, count: int | None = None) -> bool:
        """POST to the grid square's ``registered`` endpoint and return the decoded JSON response.

        When ``count`` is given it is sent as the ``count`` query parameter.
        """
        endpoint = f"gridsquares/{gridsquare_uuid}/registered"
        if count is not None:
            endpoint = f"{endpoint}?count={count}"
        return self._request("post", endpoint)

    # FoilHoles

    def get_foilholes(self) -> list[FoilHoleResponse]:
        """Get all foil holes"""
        return self._request("get", "foilholes", response_cls=FoilHoleResponse)

    def get_foilhole(self, foilhole_uuid: str) -> FoilHoleResponse:
        """Get a single foil hole by UUID"""
        return self._request("get", f"foilholes/{foilhole_uuid}", response_cls=FoilHoleResponse)

    def update_foilhole(self, foilhole: FoilHoleData) -> FoilHoleResponse:
        """Update a foil hole"""
        payload = EntityConverter.foilhole_to_request(foilhole)
        return self._request("put", f"foilholes/{payload.uuid}", payload, FoilHoleResponse)

    def delete_foilhole(self, foilhole_uuid: str) -> None:
        """Delete a foil hole"""
        return self._request("delete", f"foilholes/{foilhole_uuid}")

    def get_gridsquare_foilholes(self, gridsquare_uuid: str) -> list[FoilHoleResponse]:
        """Get all foil holes for a specific grid square"""
        return self._request("get", f"gridsquares/{gridsquare_uuid}/foilholes", response_cls=FoilHoleResponse)

    def create_gridsquare_foilholes(
        self, gridsquare_uuid: str, foilholes: list[FoilHoleData], allow_on_grid_bar: bool = False
    ) -> list[FoilHoleResponse]:
        """Create foil holes for a specific grid square

        Foil holes flagged ``is_near_grid_bar`` are skipped unless
        ``allow_on_grid_bar`` is True.
        """
        payload = [
            EntityConverter.foilhole_to_request(fh)
            for fh in foilholes
            if allow_on_grid_bar or not fh.is_near_grid_bar
        ]
        # this currently assumes all foil holes are on the same square
        return self._request("post", f"gridsquares/{gridsquare_uuid}/foilholes", payload, FoilHoleResponse)

    # Micrographs

    def get_micrographs(self) -> list[MicrographResponse]:
        """Get all micrographs"""
        return self._request("get", "micrographs", response_cls=MicrographResponse)

    def get_micrograph(self, micrograph_uuid: str) -> MicrographResponse:
        """Get a single micrograph by UUID"""
        return self._request("get", f"micrographs/{micrograph_uuid}", response_cls=MicrographResponse)

    def update_micrograph(self, micrograph: MicrographData) -> MicrographResponse:
        """Update a micrograph"""
        payload = EntityConverter.micrograph_to_request(micrograph)
        return self._request("put", f"micrographs/{payload.uuid}", payload, MicrographResponse)

    def delete_micrograph(self, micrograph_id: str) -> None:
        """Delete a micrograph"""
        return self._request("delete", f"micrographs/{micrograph_id}")

    def get_foilhole_micrographs(self, foilhole_id: str) -> list[MicrographResponse]:
        """Get all micrographs for a specific foil hole"""
        return self._request("get", f"foilholes/{foilhole_id}/micrographs", response_cls=MicrographResponse)

    def create_foilhole_micrograph(self, micrograph: MicrographData) -> MicrographResponse:
        """Create a new micrograph for a specific foil hole"""
        payload = EntityConverter.micrograph_to_request(micrograph)
        return self._request("post", f"foilholes/{payload.foilhole_uuid}/micrographs", payload, MicrographResponse)

    # ============ Agent Communication Methods ============

    def acknowledge_instruction(
        self, agent_id: str, session_id: str, instruction_id: str, acknowledgement: AgentInstructionAcknowledgement
    ) -> AgentInstructionAcknowledgementResponse:
        """Acknowledge an instruction from the agent"""
        return self._request(
            "post",
            f"agent/{agent_id}/session/{session_id}/instructions/{instruction_id}/ack",
            acknowledgement,
            AgentInstructionAcknowledgementResponse,
        )

    def get_active_connections(self) -> dict:
        """Get active agent connections (debug endpoint)"""
        return self._request("get", "debug/agent-connections")

    def get_session_instructions(self, session_id: str) -> dict:
        """Get instructions for a session (debug endpoint)"""
        return self._request("get", f"debug/session/{session_id}/instructions")
[docs]
class SSEAgentClient:
"""
SSE client for agents to receive real-time instructions from the backend.
This is separate from the main ApiClient as it handles long-lived connections.
"""
def __init__(
self,
base_url: str,
agent_id: str,
session_id: str,
timeout: int = 30,
max_retries: int = 10,
initial_retry_delay: float = 1.0,
max_retry_delay: float = 60.0,
):
"""
Initialize SSE client for agent communication
Args:
base_url: Base URL of the API server
agent_id: Unique identifier for this agent/microscope
session_id: Current microscopy session ID
timeout: Connection timeout in seconds
max_retries: Maximum number of reconnection attempts
initial_retry_delay: Initial delay between retries in seconds
max_retry_delay: Maximum delay between retries in seconds
"""
self.base_url = base_url.rstrip("/")
self.agent_id = agent_id
self.session_id = session_id
self.timeout = timeout
self.max_retries = max_retries
self.initial_retry_delay = initial_retry_delay
self.max_retry_delay = max_retry_delay
self.logger = logging.getLogger(f"SSEAgentClient-{agent_id}")
self._is_running = False
self._connection_id: str | None = None
self._stats = {
"total_connections": 0,
"successful_connections": 0,
"failed_connections": 0,
"instructions_received": 0,
"instructions_acknowledged": 0,
"last_connection_time": None,
"last_instruction_time": None,
}
[docs]
def stream_instructions(
self,
instruction_callback: Callable[[dict], None],
connection_callback: Callable[[dict], None] | None = None,
error_callback: Callable[[Exception], None] | None = None,
) -> None:
"""
Start streaming instructions via SSE (blocking)
Args:
instruction_callback: Called when an instruction is received
connection_callback: Called when connection events occur (optional)
error_callback: Called when errors occur (optional)
"""
stream_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/instructions/stream"
self.logger.info(f"Starting SSE stream for agent {self.agent_id}, session {self.session_id}")
self._is_running = True
self._stats["total_connections"] += 1
try:
response = requests.get(
stream_url, headers={"Accept": "text/event-stream"}, stream=True, timeout=self.timeout
)
response.raise_for_status()
self._stats["successful_connections"] += 1
self._stats["last_connection_time"] = datetime.now().isoformat()
client = sseclient.SSEClient(response)
for event in client.events():
if not self._is_running:
break
try:
data = json.loads(event.data)
event_type = data.get("type")
match event_type:
case "connection":
self._connection_id = data.get("connection_id")
self.logger.info(f"Connected with connection_id: {self._connection_id}")
if connection_callback:
connection_callback(data)
case "heartbeat":
self.logger.debug(f"Heartbeat received at {data.get('timestamp')}")
case "instruction":
self._stats["instructions_received"] += 1
self._stats["last_instruction_time"] = datetime.now().isoformat()
self.logger.info(
f"Instruction received: {data.get('instruction_id')} - {data.get('instruction_type')}"
)
instruction_callback(data)
case "error":
error_msg = data.get("message", "Unknown error")
error = ConnectionError(f"Server error: {error_msg}")
self.logger.error(f"Server error received: {error_msg}")
if error_callback:
error_callback(error)
break
case _:
self.logger.warning(f"Unknown event type: {event_type}")
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse SSE data: {e}")
if error_callback:
error_callback(e)
except Exception as e:
self.logger.error(f"Error processing SSE event: {e}")
if error_callback:
error_callback(e)
except requests.exceptions.RequestException as e:
self._stats["failed_connections"] += 1
self.logger.error(f"SSE connection error: {e}")
if error_callback:
error_callback(e)
except Exception as e:
self._stats["failed_connections"] += 1
self.logger.error(f"Unexpected SSE error: {e}")
if error_callback:
error_callback(e)
finally:
self._is_running = False
self.logger.info("SSE stream ended")
def _calculate_backoff_delay(self, retry_count: int) -> float:
"""Calculate exponential backoff delay with jitter."""
delay = min(self.initial_retry_delay * (2**retry_count), self.max_retry_delay)
# Add jitter (±25% of delay)
jitter = delay * 0.25 * (2 * random.random() - 1)
return max(0.1, delay + jitter)
[docs]
async def stream_instructions_async(
self,
instruction_callback: Callable[[dict], None],
connection_callback: Callable[[dict], None] | None = None,
error_callback: Callable[[Exception], None] | None = None,
auto_retry: bool = True,
) -> None:
"""
Start streaming instructions via SSE (async with auto-retry and exponential backoff)
Args:
instruction_callback: Called when an instruction is received
connection_callback: Called when connection events occur (optional)
error_callback: Called when errors occur (optional)
auto_retry: Whether to automatically retry on connection failures
"""
retry_count = 0
last_error: Exception | None = None
# Ensure we're running
self._is_running = True
while retry_count <= self.max_retries and self._is_running:
try:
self.logger.info(f"Starting SSE connection (attempt {retry_count + 1}/{self.max_retries + 1})")
# Run the synchronous streaming in a thread pool
await asyncio.get_event_loop().run_in_executor(
None, self.stream_instructions, instruction_callback, connection_callback, error_callback
)
# If we get here, the connection ended gracefully (user stopped it)
if not self._is_running:
self.logger.info("Connection stopped by user")
break
# If auto_retry is disabled, exit after one attempt
if not auto_retry:
break
except Exception as e:
last_error = e
retry_count += 1
self.logger.error(f"SSE connection failed (attempt {retry_count}/{self.max_retries + 1}): {e}")
if retry_count <= self.max_retries and auto_retry and self._is_running:
delay = self._calculate_backoff_delay(retry_count - 1)
self.logger.info(f"Retrying in {delay:.2f} seconds...")
await asyncio.sleep(delay)
else:
self.logger.error("Max retries reached or auto_retry disabled, giving up")
if error_callback and last_error:
error_callback(last_error)
break
[docs]
def acknowledge_instruction(
self,
instruction_id: str,
status: str,
result: str | None = None,
error_message: str | None = None,
processing_time_ms: int | None = None,
retry_count: int = 3,
) -> AgentInstructionAcknowledgementResponse:
"""
Acknowledge an instruction with retry logic
Args:
instruction_id: ID of the instruction to acknowledge
status: Status of acknowledgement ('received', 'processed', 'failed', 'declined')
result: Optional result message
error_message: Optional error message if status is 'failed'
processing_time_ms: Time taken to process the instruction in milliseconds
retry_count: Number of retry attempts for acknowledgement
"""
acknowledgement = AgentInstructionAcknowledgement(
status=status,
result=result,
error_message=error_message,
processing_time_ms=processing_time_ms,
processed_at=datetime.now(),
)
ack_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/instructions/{instruction_id}/ack"
last_error = None
for attempt in range(retry_count):
try:
response = requests.post(
ack_url,
json=acknowledgement.model_dump(mode="json"),
headers={"Content-Type": "application/json"},
timeout=self.timeout,
)
response.raise_for_status()
ack_response = AgentInstructionAcknowledgementResponse(**response.json())
self._stats["instructions_acknowledged"] += 1
self.logger.info(f"Successfully acknowledged instruction {instruction_id} with status {status}")
return ack_response
except requests.exceptions.RequestException as e:
last_error = e
if attempt < retry_count - 1:
delay = self._calculate_backoff_delay(attempt)
self.logger.warning(
f"Failed to acknowledge instruction {instruction_id} (attempt {attempt + 1}), "
f"retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
else:
self.logger.error(
f"Failed to acknowledge instruction {instruction_id} after {retry_count} attempts: {e}"
)
except Exception as e:
last_error = e
self.logger.error(f"Unexpected error acknowledging instruction {instruction_id}: {e}")
break
# If we get here, all retries failed
raise last_error if last_error else Exception("Unknown acknowledgement error")
[docs]
def get_stats(self) -> dict:
"""Get client connection and performance statistics."""
return {
**self._stats,
"agent_id": self.agent_id,
"session_id": self.session_id,
"connection_id": self._connection_id,
"is_running": self._is_running,
"max_retries": self.max_retries,
"success_rate": (self._stats["successful_connections"] / max(self._stats["total_connections"], 1)) * 100,
}
[docs]
def reset_stats(self) -> None:
"""Reset client statistics."""
self._stats = {
"total_connections": 0,
"successful_connections": 0,
"failed_connections": 0,
"instructions_received": 0,
"instructions_acknowledged": 0,
"last_connection_time": None,
"last_instruction_time": None,
}
[docs]
def is_connected(self) -> bool:
"""Check if the client is currently connected and streaming."""
return self._is_running and self._connection_id is not None
[docs]
def send_heartbeat(self, retry_count: int = 3) -> bool:
"""
Send a heartbeat to the backend to update connection health status
Args:
retry_count: Number of retry attempts for heartbeat
Returns:
bool: True if heartbeat was sent successfully, False otherwise
"""
heartbeat_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/heartbeat"
for attempt in range(retry_count):
try:
response = requests.post(
heartbeat_url,
headers={"Content-Type": "application/json"},
timeout=self.timeout,
)
response.raise_for_status()
heartbeat_response = response.json()
self.logger.debug(
f"Heartbeat sent successfully: {heartbeat_response.get('heartbeat_timestamp', 'unknown')}"
)
return True
except requests.exceptions.RequestException as e:
if attempt < retry_count - 1:
delay = self._calculate_backoff_delay(attempt)
self.logger.warning(
f"Failed to send heartbeat (attempt {attempt + 1}), retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
else:
self.logger.error(f"Failed to send heartbeat after {retry_count} attempts: {e}")
except Exception as e:
self.logger.error(f"Unexpected error sending heartbeat: {e}")
break
return False
[docs]
def stop(self):
"""Stop the SSE stream"""
self.logger.info("Stopping SSE stream...")
self._is_running = False