#!/usr/bin/env python3
"""
Agent Connection Management Service
Handles connection health monitoring, instruction expiration, and cleanup
for the SmartEM backend-to-agent communication system.
"""
import asyncio
import logging
import uuid
from datetime import datetime, timedelta
from typing import Any
from sqlalchemy import and_
from sqlmodel import Session
from smartem_backend.model.database import AgentConnection, AgentInstruction, AgentSession
from smartem_backend.mq_publisher import publish_agent_instruction_expired
from smartem_backend.utils import get_db_engine
class AgentConnectionManager:
    """
    Manages agent connections, health monitoring, and instruction lifecycle.
    Responsibilities:
    - Monitor connection health and detect stale connections
    - Handle instruction expiration and retry logic
    - Clean up inactive sessions and connections
    - Provide connection statistics and monitoring
    """
    def __init__(self, db_engine=None, check_interval: int = 30):
        """
        Initialize the connection manager.
        Args:
            db_engine: Database engine (defaults to global engine)
            check_interval: How often to run cleanup tasks (seconds)
        """
        self.db_engine = db_engine or get_db_engine()
        self.check_interval = check_interval
        self.logger = logging.getLogger("AgentConnectionManager")
        self._running = False
        self._task: asyncio.Task | None = None
[docs]
    async def start(self):
        """Start the connection manager background tasks."""
        if self._running:
            self.logger.warning("Connection manager already running")
            return
        self._running = True
        self.logger.info(f"Starting agent connection manager (check interval: {self.check_interval}s)")
        # Start background monitoring task
        self._task = asyncio.create_task(self._monitoring_loop()) 
[docs]
    async def stop(self):
        """Stop the connection manager background tasks."""
        if not self._running:
            return
        self.logger.info("Stopping agent connection manager")
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass 
    async def _monitoring_loop(self):
        """Main monitoring loop that runs cleanup tasks periodically."""
        while self._running:
            try:
                await self._run_cleanup_tasks()
                await asyncio.sleep(self.check_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Error in monitoring loop: {e}")
                await asyncio.sleep(self.check_interval)
    async def _run_cleanup_tasks(self):
        """Run all cleanup and monitoring tasks."""
        await asyncio.gather(
            self._cleanup_stale_connections(),
            self._handle_expired_instructions(),
            self._update_session_activity(),
            return_exceptions=True,
        )
    async def _cleanup_stale_connections(self):
        """Clean up connections that haven't received heartbeats recently."""
        try:
            with Session(self.db_engine) as session:
                # Consider connections stale if no heartbeat for 2 minutes
                stale_threshold = datetime.now() - timedelta(minutes=2)
                stale_connections = (
                    session.query(AgentConnection)
                    .filter(
                        and_(AgentConnection.status == "active", AgentConnection.last_heartbeat_at < stale_threshold)
                    )
                    .all()
                )
                for conn in stale_connections:
                    self.logger.info(f"Marking stale connection {conn.connection_id} as closed")
                    conn.status = "closed"
                    conn.closed_at = datetime.now()
                    conn.close_reason = "stale_connection"
                if stale_connections:
                    session.commit()
                    self.logger.info(f"Cleaned up {len(stale_connections)} stale connections")
        except Exception as e:
            self.logger.error(f"Error cleaning up stale connections: {e}")
    async def _handle_expired_instructions(self):
        """Handle instructions that have expired and need retry or failure logic."""
        try:
            with Session(self.db_engine) as session:
                # Find instructions that have expired
                now = datetime.now()
                expired_instructions = (
                    session.query(AgentInstruction)
                    .filter(
                        and_(
                            AgentInstruction.status.in_(["pending", "sent"]),
                            AgentInstruction.expires_at.is_not(None),
                            AgentInstruction.expires_at <= now,
                        )
                    )
                    .all()
                )
                for instruction in expired_instructions:
                    self.logger.info(
                        f"Processing expired instruction {instruction.instruction_id} "
                        f"(retry {instruction.retry_count}/{instruction.max_retries})"
                    )
                    # Publish expiration event for processing
                    success = publish_agent_instruction_expired(
                        instruction_id=instruction.instruction_id,
                        session_id=instruction.session_id,
                        agent_id=instruction.agent_id,
                        expires_at=instruction.expires_at,
                        retry_count=instruction.retry_count + 1,
                    )
                    if success:
                        # Update retry count immediately
                        instruction.retry_count += 1
                        if instruction.retry_count >= instruction.max_retries:
                            instruction.status = "expired"
                            self.logger.info(
                                f"Instruction {instruction.instruction_id} marked as expired "
                                f"after {instruction.retry_count} retries"
                            )
                        else:
                            # Reset for retry with new expiration time
                            instruction.status = "pending"
                            instruction.expires_at = now + timedelta(minutes=5)  # 5-minute retry window
                            self.logger.info(
                                f"Instruction {instruction.instruction_id} reset for retry "
                                f"({instruction.retry_count}/{instruction.max_retries})"
                            )
                    else:
                        self.logger.error(
                            f"Failed to publish expiration event for instruction {instruction.instruction_id}"
                        )
                if expired_instructions:
                    session.commit()
                    self.logger.info(f"Processed {len(expired_instructions)} expired instructions")
        except Exception as e:
            self.logger.error(f"Error handling expired instructions: {e}")
    async def _update_session_activity(self):
        """Update session activity and mark inactive sessions."""
        try:
            with Session(self.db_engine) as session:
                # Mark sessions inactive if no activity for 1 hour
                inactive_threshold = datetime.now() - timedelta(hours=1)
                inactive_sessions = (
                    session.query(AgentSession)
                    .filter(and_(AgentSession.status == "active", AgentSession.last_activity_at < inactive_threshold))
                    .all()
                )
                for agent_session in inactive_sessions:
                    self.logger.info(f"Marking session {agent_session.session_id} as inactive")
                    agent_session.status = "inactive"
                if inactive_sessions:
                    session.commit()
                    self.logger.info(f"Marked {len(inactive_sessions)} sessions as inactive")
        except Exception as e:
            self.logger.error(f"Error updating session activity: {e}")
[docs]
    def get_connection_stats(self) -> dict[str, Any]:
        """Get current connection and session statistics."""
        try:
            with Session(self.db_engine) as session:
                # Active connections
                active_connections = session.query(AgentConnection).filter(AgentConnection.status == "active").count()
                # Active sessions
                active_sessions = session.query(AgentSession).filter(AgentSession.status == "active").count()
                # Pending instructions
                pending_instructions = (
                    session.query(AgentInstruction).filter(AgentInstruction.status == "pending").count()
                )
                # Sent but not acknowledged instructions
                sent_instructions = session.query(AgentInstruction).filter(AgentInstruction.status == "sent").count()
                return {
                    "active_connections": active_connections,
                    "active_sessions": active_sessions,
                    "pending_instructions": pending_instructions,
                    "sent_instructions": sent_instructions,
                    "total_instructions_pending": pending_instructions + sent_instructions,
                    "timestamp": datetime.now().isoformat(),
                }
        except Exception as e:
            self.logger.error(f"Error getting connection stats: {e}")
            return {
                "error": str(e),
                "timestamp": datetime.now().isoformat(),
            } 
[docs]
    def create_session(
        self,
        agent_id: str,
        acquisition_uuid: str | None = None,
        name: str | None = None,
        description: str | None = None,
        experimental_parameters: dict | None = None,
    ) -> str:
        """
        Create a new agent session.
        Args:
            agent_id: Unique identifier for the agent
            acquisition_uuid: Associated acquisition UUID (optional)
            name: Session name (optional)
            description: Session description (optional)
            experimental_parameters: Experimental parameters (optional)
        Returns:
            str: Created session ID
        """
        session_id = str(uuid.uuid4())
        try:
            with Session(self.db_engine) as db_session:
                agent_session = AgentSession(
                    session_id=session_id,
                    agent_id=agent_id,
                    acquisition_uuid=acquisition_uuid,
                    name=name or f"Session-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
                    description=description,
                    experimental_parameters=experimental_parameters or {},
                    status="active",
                    created_at=datetime.now(),
                    last_activity_at=datetime.now(),
                )
                db_session.add(agent_session)
                db_session.commit()
                self.logger.info(f"Created session {session_id} for agent {agent_id}")
                return session_id
        except Exception as e:
            self.logger.error(f"Error creating session: {e}")
            raise 
[docs]
    def close_session(self, session_id: str) -> bool:
        """
        Close an agent session and clean up associated connections.
        Args:
            session_id: Session ID to close
        Returns:
            bool: True if session was closed successfully
        """
        try:
            with Session(self.db_engine) as session:
                # Mark session as ended
                agent_session = session.query(AgentSession).filter(AgentSession.session_id == session_id).first()
                if agent_session:
                    agent_session.status = "ended"
                    agent_session.ended_at = datetime.now()
                # Close associated connections
                connections = session.query(AgentConnection).filter(AgentConnection.session_id == session_id).all()
                for conn in connections:
                    if conn.status == "active":
                        conn.status = "closed"
                        conn.closed_at = datetime.now()
                        conn.close_reason = "session_closed"
                session.commit()
                self.logger.info(f"Closed session {session_id} and {len(connections)} connections")
                return True
        except Exception as e:
            self.logger.error(f"Error closing session {session_id}: {e}")
            return False 
 
# Global connection manager instance: a module-level singleton, left as None
# until get_connection_manager() lazily creates it on first call.
_connection_manager: AgentConnectionManager | None = None
def get_connection_manager() -> AgentConnectionManager:
    """
    Return the module-wide connection manager, creating it on first use.

    The instance is stored in the module global ``_connection_manager`` and built
    with default arguments, so later calls return the same object.
    """
    global _connection_manager
    if _connection_manager is not None:
        return _connection_manager
    _connection_manager = AgentConnectionManager()
    return _connection_manager
async def start_connection_manager():
    """Start the background tasks of the shared connection manager (creating it if needed)."""
    await get_connection_manager().start()
async def stop_connection_manager():
    """
    Stop the global connection manager, if one has been created.

    When no manager exists yet this is a no-op. Going through
    ``get_connection_manager()`` here would construct a brand-new manager (and
    resolve a database engine) only to stop something that was never started.
    """
    manager = _connection_manager
    if manager is None:
        return
    await manager.stop()