# Source code for smartem_backend.api_client

import asyncio
import json
import logging
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime

import requests
import sseclient
from pydantic import BaseModel

from smartem_backend.model.http_request import (
    AcquisitionCreateRequest,
    AgentInstructionAcknowledgement,
    AtlasCreateRequest,
    AtlasTileCreateRequest,
    FoilHoleCreateRequest,
    GridCreateRequest,
    GridSquareCreateRequest,
    GridSquarePositionRequest,
    MicrographCreateRequest,
)
from smartem_backend.model.http_response import (
    AcquisitionResponse,
    AgentInstructionAcknowledgementResponse,
    AtlasResponse,
    AtlasTileGridSquarePositionResponse,
    AtlasTileResponse,
    FoilHoleResponse,
    GridResponse,
    GridSquareResponse,
    MicrographResponse,
)
from smartem_common.entity_status import AcquisitionStatus, GridSquareStatus, GridStatus
from smartem_common.schemas import (
    AcquisitionData,
    AtlasData,
    AtlasTileData,
    AtlasTileGridSquarePositionData,
    FoilHoleData,
    GridData,
    GridSquareData,
    MicrographData,
)

# TODO look for a way to remove the extra bloat - conversion from EntityData type to EntityCreateRequest type
#  if at all possible


class EntityConverter:
    """
    Maps EPU data-model objects onto the request models sent to the API.

    Keeping the field-by-field mappings in one place keeps the client classes
    free of conversion detail.
    """

    @staticmethod
    def _optional_fields(source, names):
        """Return ``{name: source.<name>}`` for every name, or all ``None`` when ``source`` is falsy."""
        return {name: (getattr(source, name) if source else None) for name in names}

    @staticmethod
    def _tile_request(tile, atlas_uuid) -> AtlasTileCreateRequest:
        """Build one atlas tile request that is attached to ``atlas_uuid``."""
        placement = tile.tile_position
        return AtlasTileCreateRequest(
            atlas_uuid=atlas_uuid,
            uuid=tile.uuid,
            tile_id=tile.id,
            position_x=placement.position[0],
            position_y=placement.position[1],
            size_x=placement.size[0],
            size_y=placement.size[1],
            file_format=tile.file_format,
            base_filename=tile.base_filename,
        )

    @staticmethod
    def acquisition_to_request(entity: AcquisitionData) -> AcquisitionCreateRequest:
        """Build an acquisition create request from EPU session data.

        The instrument fields are ``None`` when the entity has no instrument;
        the status is always ``AcquisitionStatus.STARTED``.
        """
        instrument = entity.instrument
        # TODO check if natural `id` should also be included
        return AcquisitionCreateRequest(
            uuid=entity.uuid,
            name=entity.name,
            start_time=entity.start_time,
            storage_path=entity.storage_path,
            atlas_path=entity.atlas_path,
            clustering_mode=entity.clustering_mode,
            clustering_radius=entity.clustering_radius,
            instrument_model=instrument.instrument_model if instrument else None,
            instrument_id=instrument.instrument_id if instrument else None,
            computer_name=instrument.computer_name if instrument else None,
            status=AcquisitionStatus.STARTED,
        )

    @staticmethod
    def grid_to_request(entity: GridData, lowmag: bool = False) -> GridCreateRequest:
        """Build a grid create request from grid data.

        The grid name is taken from the parent acquisition, falling back to
        ``"Unknown"`` when there is none; the status is ``GridStatus.NONE``.
        """
        acquisition = entity.acquisition_data
        data_dir = entity.data_dir
        atlas_dir = entity.atlas_dir
        return GridCreateRequest(
            uuid=entity.uuid,
            status=GridStatus.NONE,
            name=acquisition.name if acquisition else "Unknown",
            # NOTE(review): unlike `name`, this access is not guarded, so a grid
            # without acquisition data raises AttributeError here - confirm
            # whether that is intended.
            acquisition_uuid=acquisition.uuid,
            data_dir=str(data_dir) if data_dir else None,
            atlas_dir=str(atlas_dir) if atlas_dir else None,
            lowmag=lowmag,
        )

    @staticmethod
    def gridsquare_to_request(entity: GridSquareData, lowmag: bool = False) -> GridSquareCreateRequest:
        """Build a grid square create request from grid square data.

        Fields sourced from the optional ``metadata`` and ``manifest`` objects
        are ``None`` when the respective object is missing; the status is
        ``GridSquareStatus.NONE``.
        """
        meta = entity.metadata
        stage = meta.stage_position if meta else None
        image_path = meta.image_path if meta else None
        data_dir = entity.data_dir

        from_metadata = EntityConverter._optional_fields(
            meta, ("atlas_node_id", "state", "rotation", "selected", "unusable")
        )
        from_manifest = EntityConverter._optional_fields(
            entity.manifest,
            (
                "acquisition_datetime",
                "defocus",
                "magnification",
                "pixel_size",
                "detector_name",
                "applied_defocus",
            ),
        )

        return GridSquareCreateRequest(
            grid_uuid=entity.grid_uuid,
            gridsquare_id=entity.gridsquare_id,
            uuid=entity.uuid,
            center_x=entity.center_x,
            center_y=entity.center_y,
            size_width=entity.size_width,
            size_height=entity.size_height,
            status=GridSquareStatus.NONE,
            data_dir=str(data_dir) if data_dir else None,
            image_path=str(image_path) if image_path else None,
            stage_position_x=stage.x if stage else None,
            stage_position_y=stage.y if stage else None,
            stage_position_z=stage.z if stage else None,
            lowmag=lowmag,
            **from_metadata,
            **from_manifest,
        )

    @staticmethod
    def foilhole_to_request(entity: FoilHoleData) -> FoilHoleCreateRequest:
        """Build a foil hole create request from foil hole data."""
        # Every field except the natural id is copied under the same name.
        copied = (
            "uuid",
            "gridsquare_id",
            "gridsquare_uuid",
            "center_x",
            "center_y",
            "quality",
            "rotation",
            "size_width",
            "size_height",
            "x_location",
            "y_location",
            "x_stage_position",
            "y_stage_position",
            "diameter",
            "is_near_grid_bar",
        )
        return FoilHoleCreateRequest(
            foilhole_id=entity.id,  # the request model names the natural id `foilhole_id`
            **{field: getattr(entity, field) for field in copied},
        )

    @staticmethod
    def micrograph_to_request(entity: MicrographData) -> MicrographCreateRequest:
        """Build a micrograph create request from micrograph data.

        Manifest-derived fields are ``None`` when the micrograph has no manifest.
        """
        high_res_path = entity.high_res_path
        manifest_file = entity.manifest_file
        from_manifest = EntityConverter._optional_fields(
            entity.manifest,
            (
                "acquisition_datetime",
                "defocus",
                "detector_name",
                "energy_filter",
                "phase_plate",
                "image_size_x",
                "image_size_y",
                "binning_x",
                "binning_y",
            ),
        )
        return MicrographCreateRequest(
            uuid=entity.uuid,
            foilhole_uuid=entity.foilhole_uuid,
            foilhole_id=entity.foilhole_id,
            location_id=entity.location_id,
            high_res_path=str(high_res_path) if high_res_path else None,
            manifest_file=str(manifest_file) if manifest_file else None,
            **from_manifest,
        )

    @staticmethod
    def atlas_to_request(entity: AtlasData) -> AtlasCreateRequest:
        """Build an atlas create request, including one nested request per tile.

        Each nested tile request is attached to this atlas' own uuid.
        """
        tile_requests = []
        for tile in entity.tiles:
            tile_requests.append(EntityConverter._tile_request(tile, entity.uuid))
        return AtlasCreateRequest(
            uuid=entity.uuid,
            atlas_id=entity.id,
            grid_uuid=entity.grid_uuid,
            name=entity.name,
            storage_folder=entity.storage_folder,
            acquisition_date=entity.acquisition_date,
            tiles=tile_requests,
        )

    @staticmethod
    def atlas_tile_to_request(entity: AtlasTileData) -> AtlasTileCreateRequest:
        """Build an atlas tile create request from stand-alone atlas tile data."""
        return EntityConverter._tile_request(entity, entity.atlas_uuid)

    @staticmethod
    def gridsquare_position_to_request(entity: AtlasTileGridSquarePositionData) -> GridSquarePositionRequest:
        """Build a grid square position request from a grid square's position on an atlas tile."""
        position = entity.position
        size = entity.size
        return GridSquarePositionRequest(
            center_x=position[0],
            center_y=position[1],
            size_width=size[0],
            size_height=size[1],
            gridsquare_uuid=entity.gridsquare_uuid,
        )
class SmartEMAPIClient:
    """
    SmartEM API client that provides a synchronous HTTP interface.

    The client wraps a single ``requests.Session``, converts EPU data models
    to API request models via ``EntityConverter`` and validates responses
    into the matching response models. It can be used as a context manager,
    which closes the session on exit.
    """

    def __init__(self, base_url: str, timeout: float = 10.0, logger=None):
        """
        Initialize the SmartEM API client

        Args:
            base_url: Base URL for the API (a trailing slash is stripped)
            timeout: Request timeout in seconds, applied to every request
            logger: Optional custom logger instance
        """
        self.base_url = base_url.rstrip("/")
        # requests ignores a `timeout` attribute set on the Session, so the
        # value is stored here and passed explicitly with every request.
        self.timeout = timeout
        self._session = requests.Session()
        self._logger = logger or logging.getLogger(__name__)

        # Configure logger if it's the default one
        if logger is None:
            # The module logger is shared between instances: attach a handler
            # only once, otherwise every extra client duplicates each log line.
            if not self._logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
                handler.setFormatter(formatter)
                self._logger.addHandler(handler)
            self._logger.setLevel(logging.INFO)

        self._logger.info(f"Initialized SmartEM API client with base URL: {base_url}")

    def close(self) -> None:
        """Close the client connection"""
        try:
            self._session.close()
        except Exception as e:
            self._logger.error(f"Error closing session: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    # Generic API request methods
    @staticmethod
    def _serialise_body(request_model):
        """Turn a request model, list of models or plain dict into JSON-ready data."""
        if request_model is None:
            return None
        if hasattr(request_model, "model_dump"):
            # It's a Pydantic model
            return request_model.model_dump(mode="json", exclude_none=True)
        if isinstance(request_model, list):
            return [m.model_dump(mode="json", exclude_none=True) for m in request_model]
        # It's already a dict, but might contain datetime objects
        return {k: v.isoformat() if isinstance(v, datetime) else v for k, v in request_model.items()}

    def _parse_response(self, response, url: str, response_cls=None):
        """Decode the JSON body of ``response`` and optionally validate it into ``response_cls``."""
        try:
            data = response.json()
        except ValueError as e:
            # Every JSON decode error requests can raise (stdlib json or
            # simplejson backed) derives from ValueError.
            self._logger.error(f"Could not parse JSON response from {url}: {e}")
            self._logger.debug(f"Raw response: {response.text}")
            raise ValueError(f"Invalid JSON response: {str(e)}") from None

        self._logger.debug(f"Response from {url}: {data}")

        # Parse response if response_cls is provided
        if not response_cls:
            return data
        try:
            if isinstance(data, list):
                return [response_cls.model_validate(item) for item in data]
            return response_cls.model_validate(data)
        except Exception as e:
            self._logger.error(f"Error validating response data from {url}: {e}")
            self._logger.debug(f"Response data that failed validation: {data}")
            raise ValueError(f"Invalid response data: {str(e)}") from None

    def _request(
        self,
        method: str,
        endpoint: str,
        request_model: BaseModel | dict | list[BaseModel] | None = None,
        response_cls=None,
    ):
        """
        Make a generic API request

        Args:
            method: HTTP method (get, post, put, delete)
            endpoint: API endpoint path, relative to the base URL
            request_model: Optional request data model, list of models or dict
            response_cls: Optional response class to parse the response

        Returns:
            Parsed response, list of responses, the raw decoded JSON when no
            response class is given, or None for delete operations

        Raises:
            requests.HTTPError: If the HTTP request returns an error status code
            requests.RequestException: If there's a network error or timeout
            ValueError: If there's an error parsing the response
            Exception: For any other errors
        """
        url = f"{self.base_url}/{endpoint}"

        # `is not None` rather than truthiness: an empty list or dict is a
        # legitimate body and must still be sent.
        json_data = self._serialise_body(request_model)
        if json_data is not None:
            self._logger.debug(f"Request data for {method} {url}: {json_data}")

        try:
            self._logger.debug(f"Making {method.upper()} request to {url}")
            response = self._session.request(method, url, json=json_data, timeout=self.timeout)
            response.raise_for_status()
        except requests.HTTPError as e:
            status_code = e.response.status_code
            # Try to extract error details from the response
            try:
                error_detail = e.response.json().get("detail", str(e))
            except Exception:
                error_detail = e.response.text or str(e)
            self._logger.error(f"HTTP {status_code} error for {method.upper()} {url}: {error_detail}")
            raise
        except requests.RequestException as e:
            self._logger.error(f"Request error for {method.upper()} {url}: {e}")
            self._logger.debug(f"Request error details: {traceback.format_exc()}")
            raise
        except Exception as e:
            self._logger.error(f"Unexpected error making request to {url}: {e}")
            self._logger.debug(f"Error details: {traceback.format_exc()}")
            raise

        # For delete operations, return None
        if method.lower() == "delete":
            self._logger.info(f"Successfully deleted resource at {url}")
            return None

        return self._parse_response(response, url, response_cls)

    # Entity-specific methods

    # Status and Health
    def get_status(self) -> dict[str, object]:
        """Get API status information"""
        return self._request("get", "status")

    def get_health(self) -> dict[str, object]:
        """Get API health check information"""
        return self._request("get", "health")

    # Acquisitions
    def get_acquisitions(self) -> list[AcquisitionResponse]:
        """Get all acquisitions"""
        return self._request("get", "acquisitions", response_cls=AcquisitionResponse)

    def create_acquisition(self, acquisition: AcquisitionData) -> AcquisitionResponse:
        """Create a new acquisition"""
        request = EntityConverter.acquisition_to_request(acquisition)
        return self._request("post", "acquisitions", request, AcquisitionResponse)

    def get_acquisition(self, acquisition_uuid: str) -> AcquisitionResponse:
        """Get a single acquisition by ID"""
        return self._request("get", f"acquisitions/{acquisition_uuid}", response_cls=AcquisitionResponse)

    def update_acquisition(self, acquisition: AcquisitionData) -> AcquisitionResponse:
        """Update an acquisition"""
        request = EntityConverter.acquisition_to_request(acquisition)
        return self._request("put", f"acquisitions/{request.uuid}", request, AcquisitionResponse)

    def delete_acquisition(self, acquisition_uuid: str) -> None:
        """Delete an acquisition"""
        return self._request("delete", f"acquisitions/{acquisition_uuid}")

    # Grids
    def get_grids(self) -> list[GridResponse]:
        """Get all grids"""
        return self._request("get", "grids", response_cls=GridResponse)

    def get_grid(self, grid_uuid: str) -> GridResponse:
        """Get a single grid by ID"""
        return self._request("get", f"grids/{grid_uuid}", response_cls=GridResponse)

    def update_grid(self, grid: GridData) -> GridResponse:
        """Update a grid"""
        request = EntityConverter.grid_to_request(grid)
        return self._request("put", f"grids/{request.uuid}", request, GridResponse)

    def delete_grid(self, grid_uuid: str) -> None:
        """Delete a grid"""
        return self._request("delete", f"grids/{grid_uuid}")

    def get_acquisition_grids(self, acquisition_uuid: str) -> list[GridResponse]:
        """Get all grids for a specific acquisition"""
        return self._request("get", f"acquisitions/{acquisition_uuid}/grids", response_cls=GridResponse)

    def create_acquisition_grid(self, grid: GridData) -> GridResponse:
        """Create a new grid for a specific acquisition"""
        request = EntityConverter.grid_to_request(grid)
        return self._request("post", f"acquisitions/{request.acquisition_uuid}/grids", request, GridResponse)

    def grid_registered(self, grid_uuid: str) -> bool:
        """POST to the grid's ``registered`` endpoint and return the decoded response."""
        return self._request("post", f"grids/{grid_uuid}/registered")

    # Atlas
    def get_atlases(self) -> list[AtlasResponse]:
        """Get all atlases"""
        return self._request("get", "atlases", response_cls=AtlasResponse)

    def get_atlas(self, atlas_uuid: str) -> AtlasResponse:
        """Get a single atlas by ID"""
        return self._request("get", f"atlases/{atlas_uuid}", response_cls=AtlasResponse)

    def update_atlas(self, atlas: AtlasData) -> AtlasResponse:
        """Update an atlas"""
        request = EntityConverter.atlas_to_request(atlas)
        return self._request("put", f"atlases/{request.uuid}", request, AtlasResponse)

    def delete_atlas(self, atlas_uuid: str) -> None:
        """Delete an atlas"""
        return self._request("delete", f"atlases/{atlas_uuid}")

    def get_grid_atlas(self, grid_uuid: str) -> AtlasResponse:
        """Get the atlas for a specific grid"""
        return self._request("get", f"grids/{grid_uuid}/atlas", response_cls=AtlasResponse)

    def create_grid_atlas(self, atlas: AtlasData) -> AtlasResponse:
        """Create a new atlas for a grid"""
        # Convert AtlasData to AtlasCreateRequest
        request = EntityConverter.atlas_to_request(atlas)
        return self._request("post", f"grids/{request.grid_uuid}/atlas", request, AtlasResponse)

    # Atlas Tiles
    def get_atlas_tiles(self) -> list[AtlasTileResponse]:
        """Get all atlas tiles"""
        return self._request("get", "atlas-tiles", response_cls=AtlasTileResponse)

    def get_atlas_tile(self, tile_uuid: str) -> AtlasTileResponse:
        """Get a single atlas tile by ID"""
        return self._request("get", f"atlas-tiles/{tile_uuid}", response_cls=AtlasTileResponse)

    def update_atlas_tile(self, tile: AtlasTileData) -> AtlasTileResponse:
        """Update an atlas tile"""
        request = EntityConverter.atlas_tile_to_request(tile)
        return self._request("put", f"atlas-tiles/{request.uuid}", request, AtlasTileResponse)

    def delete_atlas_tile(self, tile_uuid: str) -> None:
        """Delete an atlas tile"""
        return self._request("delete", f"atlas-tiles/{tile_uuid}")

    def get_atlas_tiles_by_atlas(self, atlas_uuid: str) -> list[AtlasTileResponse]:
        """Get all tiles for a specific atlas"""
        return self._request("get", f"atlases/{atlas_uuid}/tiles", response_cls=AtlasTileResponse)

    def create_atlas_tile_for_atlas(self, tile: AtlasTileData) -> AtlasTileResponse:
        """Create a new tile for a specific atlas"""
        request = EntityConverter.atlas_tile_to_request(tile)
        return self._request("post", f"atlases/{request.atlas_uuid}/tiles", request, AtlasTileResponse)

    # GridSquares
    def get_gridsquares(self) -> list[GridSquareResponse]:
        """Get all grid squares"""
        return self._request("get", "gridsquares", response_cls=GridSquareResponse)

    def get_gridsquare(self, gridsquare_uuid: str) -> GridSquareResponse:
        """Get a single grid square by ID"""
        return self._request("get", f"gridsquares/{gridsquare_uuid}", response_cls=GridSquareResponse)

    def update_gridsquare(self, gridsquare: GridSquareData, lowmag: bool = False) -> GridSquareResponse:
        """Update a grid square"""
        request_model = EntityConverter.gridsquare_to_request(gridsquare, lowmag=lowmag)
        return self._request("put", f"gridsquares/{gridsquare.uuid}", request_model, GridSquareResponse)

    def delete_gridsquare(self, gridsquare_uuid: str) -> None:
        """Delete a grid square"""
        return self._request("delete", f"gridsquares/{gridsquare_uuid}")

    def get_grid_gridsquares(self, grid_uuid: str) -> list[GridSquareResponse]:
        """Get all grid squares for a specific grid"""
        return self._request("get", f"grids/{grid_uuid}/gridsquares", response_cls=GridSquareResponse)

    def create_grid_gridsquare(self, gridsquare: GridSquareData, lowmag: bool = False) -> GridSquareResponse:
        """Create a new grid square for a specific grid"""
        # Convert GridSquareData to GridSquareCreateRequest
        request = EntityConverter.gridsquare_to_request(gridsquare, lowmag=lowmag)
        return self._request("post", f"grids/{request.grid_uuid}/gridsquares", request, GridSquareResponse)

    def gridsquare_registered(self, gridsquare_uuid: str, count: int | None = None) -> bool:
        """POST to the grid square's ``registered`` endpoint, adding ``count`` as a query parameter when given."""
        if count is None:
            return self._request("post", f"gridsquares/{gridsquare_uuid}/registered")
        return self._request("post", f"gridsquares/{gridsquare_uuid}/registered?count={count}")

    # FoilHoles
    def get_foilholes(self) -> list[FoilHoleResponse]:
        """Get all foil holes"""
        return self._request("get", "foilholes", response_cls=FoilHoleResponse)

    def get_foilhole(self, foilhole_uuid: str) -> FoilHoleResponse:
        """Get a single foil hole by ID"""
        return self._request("get", f"foilholes/{foilhole_uuid}", response_cls=FoilHoleResponse)

    def update_foilhole(self, foilhole: FoilHoleData) -> FoilHoleResponse:
        """Update a foil hole"""
        request = EntityConverter.foilhole_to_request(foilhole)
        return self._request("put", f"foilholes/{request.uuid}", request, FoilHoleResponse)

    def delete_foilhole(self, foilhole_uuid: str) -> None:
        """Delete a foil hole"""
        return self._request("delete", f"foilholes/{foilhole_uuid}")

    def get_gridsquare_foilholes(self, gridsquare_uuid: str) -> list[FoilHoleResponse]:
        """Get all foil holes for a specific grid square"""
        return self._request("get", f"gridsquares/{gridsquare_uuid}/foilholes", response_cls=FoilHoleResponse)

    def create_gridsquare_foilholes(
        self, gridsquare_uuid: str, foilholes: list[FoilHoleData], allow_on_grid_bar: bool = False
    ) -> list[FoilHoleResponse]:
        """Create foil holes for a specific grid square

        Foil holes flagged ``is_near_grid_bar`` are left out unless
        ``allow_on_grid_bar`` is set.
        """
        requests_to_send = [
            EntityConverter.foilhole_to_request(fh)
            for fh in foilholes
            if (not fh.is_near_grid_bar or allow_on_grid_bar)
        ]
        # this currently assumes all foil holes are on the same square
        return self._request("post", f"gridsquares/{gridsquare_uuid}/foilholes", requests_to_send, FoilHoleResponse)

    # Micrographs
    def get_micrographs(self) -> list[MicrographResponse]:
        """Get all micrographs"""
        return self._request("get", "micrographs", response_cls=MicrographResponse)

    def get_micrograph(self, micrograph_uuid: str) -> MicrographResponse:
        """Get a single micrograph by ID"""
        return self._request("get", f"micrographs/{micrograph_uuid}", response_cls=MicrographResponse)

    def update_micrograph(self, micrograph: MicrographData) -> MicrographResponse:
        """Update a micrograph"""
        request = EntityConverter.micrograph_to_request(micrograph)
        return self._request("put", f"micrographs/{request.uuid}", request, MicrographResponse)

    def delete_micrograph(self, micrograph_id: str) -> None:
        """Delete a micrograph"""
        return self._request("delete", f"micrographs/{micrograph_id}")

    def get_foilhole_micrographs(self, foilhole_id: str) -> list[MicrographResponse]:
        """Get all micrographs for a specific foil hole"""
        return self._request("get", f"foilholes/{foilhole_id}/micrographs", response_cls=MicrographResponse)

    def create_foilhole_micrograph(self, micrograph: MicrographData) -> MicrographResponse:
        """Create a new micrograph for a specific foil hole"""
        request = EntityConverter.micrograph_to_request(micrograph)
        return self._request("post", f"foilholes/{request.foilhole_uuid}/micrographs", request, MicrographResponse)

    # ============ Agent Communication Methods ============
    def acknowledge_instruction(
        self, agent_id: str, session_id: str, instruction_id: str, acknowledgement: AgentInstructionAcknowledgement
    ) -> AgentInstructionAcknowledgementResponse:
        """Acknowledge an instruction from the agent"""
        return self._request(
            "post",
            f"agent/{agent_id}/session/{session_id}/instructions/{instruction_id}/ack",
            acknowledgement,
            AgentInstructionAcknowledgementResponse,
        )

    def get_active_connections(self) -> dict:
        """Get active agent connections (debug endpoint)"""
        return self._request("get", "debug/agent-connections")

    def get_session_instructions(self, session_id: str) -> dict:
        """Get instructions for a session (debug endpoint)"""
        return self._request("get", f"debug/session/{session_id}/instructions")
[docs] class SSEAgentClient: """ SSE client for agents to receive real-time instructions from the backend. This is separate from the main ApiClient as it handles long-lived connections. """ def __init__( self, base_url: str, agent_id: str, session_id: str, timeout: int = 30, max_retries: int = 10, initial_retry_delay: float = 1.0, max_retry_delay: float = 60.0, ): """ Initialize SSE client for agent communication Args: base_url: Base URL of the API server agent_id: Unique identifier for this agent/microscope session_id: Current microscopy session ID timeout: Connection timeout in seconds max_retries: Maximum number of reconnection attempts initial_retry_delay: Initial delay between retries in seconds max_retry_delay: Maximum delay between retries in seconds """ self.base_url = base_url.rstrip("/") self.agent_id = agent_id self.session_id = session_id self.timeout = timeout self.max_retries = max_retries self.initial_retry_delay = initial_retry_delay self.max_retry_delay = max_retry_delay self.logger = logging.getLogger(f"SSEAgentClient-{agent_id}") self._is_running = False self._connection_id: str | None = None self._stats = { "total_connections": 0, "successful_connections": 0, "failed_connections": 0, "instructions_received": 0, "instructions_acknowledged": 0, "last_connection_time": None, "last_instruction_time": None, }
[docs] def stream_instructions( self, instruction_callback: Callable[[dict], None], connection_callback: Callable[[dict], None] | None = None, error_callback: Callable[[Exception], None] | None = None, ) -> None: """ Start streaming instructions via SSE (blocking) Args: instruction_callback: Called when an instruction is received connection_callback: Called when connection events occur (optional) error_callback: Called when errors occur (optional) """ stream_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/instructions/stream" self.logger.info(f"Starting SSE stream for agent {self.agent_id}, session {self.session_id}") self._is_running = True self._stats["total_connections"] += 1 try: response = requests.get( stream_url, headers={"Accept": "text/event-stream"}, stream=True, timeout=self.timeout ) response.raise_for_status() self._stats["successful_connections"] += 1 self._stats["last_connection_time"] = datetime.now().isoformat() client = sseclient.SSEClient(response) for event in client.events(): if not self._is_running: break try: data = json.loads(event.data) event_type = data.get("type") match event_type: case "connection": self._connection_id = data.get("connection_id") self.logger.info(f"Connected with connection_id: {self._connection_id}") if connection_callback: connection_callback(data) case "heartbeat": self.logger.debug(f"Heartbeat received at {data.get('timestamp')}") case "instruction": self._stats["instructions_received"] += 1 self._stats["last_instruction_time"] = datetime.now().isoformat() self.logger.info( f"Instruction received: {data.get('instruction_id')} - {data.get('instruction_type')}" ) instruction_callback(data) case "error": error_msg = data.get("message", "Unknown error") error = ConnectionError(f"Server error: {error_msg}") self.logger.error(f"Server error received: {error_msg}") if error_callback: error_callback(error) break case _: self.logger.warning(f"Unknown event type: {event_type}") except 
json.JSONDecodeError as e: self.logger.error(f"Failed to parse SSE data: {e}") if error_callback: error_callback(e) except Exception as e: self.logger.error(f"Error processing SSE event: {e}") if error_callback: error_callback(e) except requests.exceptions.RequestException as e: self._stats["failed_connections"] += 1 self.logger.error(f"SSE connection error: {e}") if error_callback: error_callback(e) except Exception as e: self._stats["failed_connections"] += 1 self.logger.error(f"Unexpected SSE error: {e}") if error_callback: error_callback(e) finally: self._is_running = False self.logger.info("SSE stream ended")
def _calculate_backoff_delay(self, retry_count: int) -> float: """Calculate exponential backoff delay with jitter.""" delay = min(self.initial_retry_delay * (2**retry_count), self.max_retry_delay) # Add jitter (±25% of delay) jitter = delay * 0.25 * (2 * random.random() - 1) return max(0.1, delay + jitter)
[docs] async def stream_instructions_async( self, instruction_callback: Callable[[dict], None], connection_callback: Callable[[dict], None] | None = None, error_callback: Callable[[Exception], None] | None = None, auto_retry: bool = True, ) -> None: """ Start streaming instructions via SSE (async with auto-retry and exponential backoff) Args: instruction_callback: Called when an instruction is received connection_callback: Called when connection events occur (optional) error_callback: Called when errors occur (optional) auto_retry: Whether to automatically retry on connection failures """ retry_count = 0 last_error: Exception | None = None # Ensure we're running self._is_running = True while retry_count <= self.max_retries and self._is_running: try: self.logger.info(f"Starting SSE connection (attempt {retry_count + 1}/{self.max_retries + 1})") # Run the synchronous streaming in a thread pool await asyncio.get_event_loop().run_in_executor( None, self.stream_instructions, instruction_callback, connection_callback, error_callback ) # If we get here, the connection ended gracefully (user stopped it) if not self._is_running: self.logger.info("Connection stopped by user") break # If auto_retry is disabled, exit after one attempt if not auto_retry: break except Exception as e: last_error = e retry_count += 1 self.logger.error(f"SSE connection failed (attempt {retry_count}/{self.max_retries + 1}): {e}") if retry_count <= self.max_retries and auto_retry and self._is_running: delay = self._calculate_backoff_delay(retry_count - 1) self.logger.info(f"Retrying in {delay:.2f} seconds...") await asyncio.sleep(delay) else: self.logger.error("Max retries reached or auto_retry disabled, giving up") if error_callback and last_error: error_callback(last_error) break
[docs] def acknowledge_instruction( self, instruction_id: str, status: str, result: str | None = None, error_message: str | None = None, processing_time_ms: int | None = None, retry_count: int = 3, ) -> AgentInstructionAcknowledgementResponse: """ Acknowledge an instruction with retry logic Args: instruction_id: ID of the instruction to acknowledge status: Status of acknowledgement ('received', 'processed', 'failed', 'declined') result: Optional result message error_message: Optional error message if status is 'failed' processing_time_ms: Time taken to process the instruction in milliseconds retry_count: Number of retry attempts for acknowledgement """ acknowledgement = AgentInstructionAcknowledgement( status=status, result=result, error_message=error_message, processing_time_ms=processing_time_ms, processed_at=datetime.now(), ) ack_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/instructions/{instruction_id}/ack" last_error = None for attempt in range(retry_count): try: response = requests.post( ack_url, json=acknowledgement.model_dump(mode="json"), headers={"Content-Type": "application/json"}, timeout=self.timeout, ) response.raise_for_status() ack_response = AgentInstructionAcknowledgementResponse(**response.json()) self._stats["instructions_acknowledged"] += 1 self.logger.info(f"Successfully acknowledged instruction {instruction_id} with status {status}") return ack_response except requests.exceptions.RequestException as e: last_error = e if attempt < retry_count - 1: delay = self._calculate_backoff_delay(attempt) self.logger.warning( f"Failed to acknowledge instruction {instruction_id} (attempt {attempt + 1}), " f"retrying in {delay:.2f}s: {e}" ) time.sleep(delay) else: self.logger.error( f"Failed to acknowledge instruction {instruction_id} after {retry_count} attempts: {e}" ) except Exception as e: last_error = e self.logger.error(f"Unexpected error acknowledging instruction {instruction_id}: {e}") break # If we get here, all retries 
failed raise last_error if last_error else Exception("Unknown acknowledgement error")
[docs] def get_stats(self) -> dict: """Get client connection and performance statistics.""" return { **self._stats, "agent_id": self.agent_id, "session_id": self.session_id, "connection_id": self._connection_id, "is_running": self._is_running, "max_retries": self.max_retries, "success_rate": (self._stats["successful_connections"] / max(self._stats["total_connections"], 1)) * 100, }
[docs] def reset_stats(self) -> None: """Reset client statistics.""" self._stats = { "total_connections": 0, "successful_connections": 0, "failed_connections": 0, "instructions_received": 0, "instructions_acknowledged": 0, "last_connection_time": None, "last_instruction_time": None, }
[docs] def is_connected(self) -> bool: """Check if the client is currently connected and streaming.""" return self._is_running and self._connection_id is not None
[docs] def send_heartbeat(self, retry_count: int = 3) -> bool: """ Send a heartbeat to the backend to update connection health status Args: retry_count: Number of retry attempts for heartbeat Returns: bool: True if heartbeat was sent successfully, False otherwise """ heartbeat_url = f"{self.base_url}/agent/{self.agent_id}/session/{self.session_id}/heartbeat" for attempt in range(retry_count): try: response = requests.post( heartbeat_url, headers={"Content-Type": "application/json"}, timeout=self.timeout, ) response.raise_for_status() heartbeat_response = response.json() self.logger.debug( f"Heartbeat sent successfully: {heartbeat_response.get('heartbeat_timestamp', 'unknown')}" ) return True except requests.exceptions.RequestException as e: if attempt < retry_count - 1: delay = self._calculate_backoff_delay(attempt) self.logger.warning( f"Failed to send heartbeat (attempt {attempt + 1}), retrying in {delay:.2f}s: {e}" ) time.sleep(delay) else: self.logger.error(f"Failed to send heartbeat after {retry_count} attempts: {e}") except Exception as e: self.logger.error(f"Unexpected error sending heartbeat: {e}") break return False
[docs] def stop(self): """Stop the SSE stream""" self.logger.info("Stopping SSE stream...") self._is_running = False