Source code for smartem_backend.utils

import json
import logging
import os
from collections.abc import Callable
from datetime import datetime
from typing import Any

import pika
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel
from sqlalchemy.engine import Engine
from sqlmodel import create_engine

from smartem_backend.log_manager import LogConfig, LogManager
from smartem_backend.model.mq_event import MessageQueueEventType


def load_conf() -> dict | None:
    config_path = os.getenv("SMARTEM_BACKEND_CONFIG") or os.path.join(os.path.dirname(__file__), "appconfig.yml")
    try:
        with open(config_path) as f:
            conf = yaml.safe_load(f)
        return conf
    except FileNotFoundError:
        # Use basic logging since logger might not be configured yet
        print(f"Warning: Configuration file not found at {config_path}")
    except yaml.YAMLError as e:
        print(f"Warning: Error parsing YAML file: {e}")
    except Exception as e:
        print(f"Warning: An unexpected error occurred: {e}")
    return None


[docs] def get_log_file_path(conf: dict | None = None) -> str | None: """ Get the log file path with validation and fallback handling. Args: conf: Configuration dictionary (if None, will load from config file) Returns: str | None: Valid log file path or None for test environments """ # Don't create file handlers in test environment to avoid resource warnings if "pytest" in os.environ.get("_", "") or "PYTEST_CURRENT_TEST" in os.environ: return None if conf is None: conf = load_conf() # Get log file path from config or use default log_file = conf.get("app", {}).get("log_file", "smartem_backend-core.log") if conf else "smartem_backend-core.log" # Validate and ensure directory exists if log_file: log_dir = os.path.dirname(os.path.abspath(log_file)) try: # Create directory if it doesn't exist if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) # Test if we can write to the directory test_file = os.path.join(log_dir or ".", ".write_test") with open(test_file, "w") as f: f.write("test") os.remove(test_file) return log_file except (OSError, PermissionError) as e: print(f"Warning: Cannot write to log directory {log_dir}: {e}") print("Falling back to current directory") return "smartem_backend-core.log" return "smartem_backend-core.log"
[docs] def setup_logger(level: int = logging.INFO, conf: dict | None = None): """ Set up logger with consolidated configuration logic. Args: level: Logging level (default: INFO) conf: Configuration dictionary (if None, will load from config file) Returns: Configured logger instance """ file_path = get_log_file_path(conf) return LogManager.get_instance("smartem_backend").configure( LogConfig( level=level, console=True, file_path=file_path, ) )
logger = setup_logger() # Global singleton engine instance _db_engine: Engine | None = None
[docs] def setup_postgres_connection(echo=False, force_new=False) -> Engine: """ Get or create a singleton database engine with connection pooling. Args: echo: Enable SQL logging force_new: Force creation of new engine (for testing) Returns: SQLAlchemy Engine instance """ global _db_engine # Return existing engine unless forced to create new one if _db_engine is not None and not force_new: return _db_engine load_dotenv(override=False) # Don't override existing env vars as these might be coming from k8s required_env_vars = ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_HOST", "POSTGRES_PORT", "POSTGRES_DB"] env_vars = {} for key in required_env_vars: value = os.getenv(key) if value is None: logger.error(f"Error: Required environment variable '{key}' is not set") exit(1) env_vars[key] = value # Load database configuration from appconfig.yml with defaults config = load_conf() db_config = config.get("database", {}) if config else {} pool_size = db_config.get("pool_size", 10) max_overflow = db_config.get("max_overflow", 20) pool_timeout = db_config.get("pool_timeout", 30) pool_recycle = db_config.get("pool_recycle", 3600) pool_pre_ping = db_config.get("pool_pre_ping", True) # Create engine with connection pooling _db_engine = create_engine( f"postgresql+psycopg2://{env_vars['POSTGRES_USER']}:{env_vars['POSTGRES_PASSWORD']}@" f"{env_vars['POSTGRES_HOST']}:{env_vars['POSTGRES_PORT']}/{env_vars['POSTGRES_DB']}", echo=echo, # Connection pool settings from config pool_size=pool_size, # Number of connections to maintain in pool max_overflow=max_overflow, # Additional connections beyond pool_size pool_timeout=pool_timeout, # Seconds to wait for connection from pool pool_recycle=pool_recycle, # Seconds after which connection is recreated pool_pre_ping=pool_pre_ping, # Validate connections before use ) logger.info(f"Created database engine with pool_size={pool_size}, max_overflow={max_overflow}") return _db_engine
[docs] def get_db_engine() -> Engine: """ Get the singleton database engine. Creates it if it doesn't exist. Returns: SQLAlchemy Engine instance """ return setup_postgres_connection()
[docs] class RabbitMQConnection: """ Base class for RabbitMQ connection management """ def __init__( self, connection_params: dict[str, Any] | None = None, exchange: str = "", queue: str = "smartem_backend" ): """ Initialize RabbitMQ connection Args: connection_params: Dictionary with RabbitMQ connection parameters. If None, load from environment variables exchange: Exchange name to use (default is direct exchange "") queue: Queue name to use """ self.connection_params = connection_params or self._load_connection_params_from_env() self.exchange = exchange self.queue = queue self._connection = None self._channel = None @staticmethod def _load_connection_params_from_env() -> dict[str, Any]: """ Load RabbitMQ connection parameters from environment variables Returns: dict: Connection parameters """ load_dotenv(override=False) # Don't override existing env vars as these might be coming from k8s required_env_vars = ["RABBITMQ_HOST", "RABBITMQ_PORT", "RABBITMQ_USER", "RABBITMQ_PASSWORD"] for key in required_env_vars: if os.getenv(key) is None: logger.error(f"Error: Required environment variable '{key}' is not set") exit(1) return { "host": os.getenv("RABBITMQ_HOST", "localhost"), "port": int(os.getenv("RABBITMQ_PORT", "5672")), "virtual_host": os.getenv("RABBITMQ_VHOST", "/"), "credentials": { "username": os.getenv("RABBITMQ_USER", "guest"), "password": os.getenv("RABBITMQ_PASSWORD", "guest"), }, }
[docs] def connect(self) -> None: """Establish connection to RabbitMQ server""" if self._connection is None or self._connection.is_closed: try: # Extract credentials from connection_params to create proper credential object if "credentials" in self.connection_params and isinstance(self.connection_params["credentials"], dict): credentials_dict = self.connection_params["credentials"] credentials = pika.PlainCredentials( username=credentials_dict["username"], password=credentials_dict["password"] ) # Create new connection params dict with proper credentials object connection_params = {**self.connection_params, "credentials": credentials} else: connection_params = self.connection_params self._connection = pika.BlockingConnection(pika.ConnectionParameters(**connection_params)) self._channel = self._connection.channel() # Declare queue with durable=True to ensure it survives broker restarts self._channel.queue_declare(queue=self.queue, durable=True) logger.info(f"Connected to RabbitMQ and declared queue '{self.queue}'") except Exception as e: logger.error(f"Failed to connect to RabbitMQ: {str(e)}") raise
[docs] def close(self) -> None: """Close the connection to RabbitMQ""" if self._connection and self._connection.is_open: self._connection.close() self._connection = None self._channel = None logger.info("Closed connection to RabbitMQ")
[docs] def channel(self): """ Get the channel object Returns: The current channel object """ if self._channel is None: self.connect() return self._channel
def __enter__(self): self.connect() return self def __exit__(self, exc_type, exc_val, exc_tb): self.close()
[docs] class RabbitMQPublisher(RabbitMQConnection): """ Publisher class for sending messages to RabbitMQ """
[docs] def publish_event(self, event_type: MessageQueueEventType, payload: BaseModel | dict[str, Any]) -> bool: """ Publish an event to RabbitMQ Args: event_type: Type of event from EventType enum payload: Event payload, either as Pydantic model or dictionary Returns: bool: True if message was published successfully """ try: self.connect() # Convert Pydantic model to dict if needed if isinstance(payload, BaseModel): payload_dict = json.loads(payload.json()) else: payload_dict = payload # Create message with event_type and payload message = {"event_type": event_type.value, **payload_dict} # Use a custom encoder for json.dumps that handles datetime objects class DateTimeEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, datetime): return obj.isoformat() return super().default(obj) # Convert message to JSON using the custom encoder message_json = json.dumps(message, cls=DateTimeEncoder) # Publish message with delivery_mode=2 (persistent) self._channel.basic_publish( exchange=self.exchange, routing_key=self.queue, body=message_json, properties=pika.BasicProperties( delivery_mode=2, # Make message persistent content_type="application/json", ), ) return True except Exception as e: logger.error(f"Failed to publish {event_type.value} event: {str(e)}") return False
[docs] class RabbitMQConsumer(RabbitMQConnection): """ Consumer class for receiving messages from RabbitMQ """
[docs] def consume(self, callback: Callable, prefetch_count: int = 1) -> None: """ Start consuming messages from the queue Args: callback: Callback function to process messages prefetch_count: Maximum number of unacknowledged messages (default: 1) """ try: self.connect() self._channel.basic_qos(prefetch_count=prefetch_count) self._channel.basic_consume(queue=self.queue, on_message_callback=callback) logger.info(f"Consumer started, listening on queue '{self.queue}'") self._channel.start_consuming() except KeyboardInterrupt: logger.info("Consumer stopped by user") self.stop_consuming() except Exception as e: logger.error(f"Error in consumer: {e}") raise
[docs] def stop_consuming(self) -> None: """Stop consuming messages""" if self._channel and self._channel.is_open: self._channel.stop_consuming() logger.info("Stopped consuming messages")
[docs] def setup_rabbitmq(queue_name=None, exchange=None): """ Create RabbitMQ publisher and consumer instances using configuration settings Args: queue_name: Optional queue name override (if None, load from config) exchange: Optional exchange name override (if None, use default "") Returns: tuple: (RabbitMQPublisher instance, RabbitMQConsumer instance) """ # Load config to get queue_name and routing_key config = load_conf() if not queue_name and config and "rabbitmq" in config: queue_name = config["rabbitmq"]["queue_name"] routing_key = config["rabbitmq"]["routing_key"] else: # Default to "smartem_backend" if config not available queue_name = queue_name or "smartem_backend" routing_key = queue_name # Use queue_name as routing_key by default exchange = exchange or "" # Default to direct exchange if not specified # Create publisher and consumer with the same connection settings publisher = RabbitMQPublisher(connection_params=None, exchange=exchange, queue=routing_key) consumer = RabbitMQConsumer(connection_params=None, exchange=exchange, queue=queue_name) return publisher, consumer
# Load application configuration. TODO do once and share the singleton conf with rest of codebase app_config = load_conf() # Create RabbitMQ connections (available as singletons throughout the application) rmq_publisher, rmq_consumer = setup_rabbitmq( queue_name=app_config["rabbitmq"]["queue_name"] if app_config and "rabbitmq" in app_config else None, exchange=os.getenv("RABBITMQ_EXCHANGE") or "", )