"""CLI and helpers for initialising prediction model weights (smartem_backend.cli.initialise_prediction_model_weights)."""
import typer
from sqlalchemy.engine import Engine
from sqlmodel import Session, select
from smartem_backend.model.database import Grid, QualityPredictionModel, QualityPredictionModelWeight
from smartem_backend.utils import get_db_engine, logger
# [docs]
def initialise_all_models_for_grid(grid_uuid: str, engine: Engine | None = None) -> None:
    """
    Initialise prediction model weights for all available models for a specific grid.

    Every ``QualityPredictionModel`` in the database receives a
    ``QualityPredictionModelWeight`` row for ``grid_uuid`` with weight
    ``1 / number_of_models``. Models that already have a weight on this grid are
    left untouched, so re-running the function does not create duplicates or
    overwrite existing weights.

    Args:
        grid_uuid: UUID of the grid to initialise weights for
        engine: Optional database engine (uses singleton if not provided)
    """
    if engine is None:
        engine = get_db_engine()

    with Session(engine) as sess:
        # Get all available prediction models
        models = sess.exec(select(QualityPredictionModel)).all()
        if not models:
            logger.warning(f"No prediction models found to initialise for grid {grid_uuid}")
            return

        # Look up, in a single query, which models already have a weight on this
        # grid instead of issuing one existence query per model.
        existing_names = set(
            sess.exec(
                select(QualityPredictionModelWeight.prediction_model_name).where(
                    QualityPredictionModelWeight.grid_uuid == grid_uuid
                )
            ).all()
        )

        # Equal share per model, based on the total model count (including models
        # that already had a weight for this grid).
        default_weight = 1 / len(models)
        for model in models:
            if model.name in existing_names:
                logger.debug(f"Weight already exists for model '{model.name}' on grid {grid_uuid}")
                continue
            sess.add(
                QualityPredictionModelWeight(
                    grid_uuid=grid_uuid, prediction_model_name=model.name, weight=default_weight
                )
            )
            # Record the pending row so a repeated model name is not added twice.
            existing_names.add(model.name)
            logger.info(f"Initialised weight {default_weight} for model '{model.name}' on grid {grid_uuid}")

        sess.commit()
# [docs]
def initialise_prediction_model_for_grid(
    name: str, weight: float, grid_uuid: str | None = None, engine: Engine | None = None
) -> None:
    """
    Initialise a single prediction model weight for a grid (CLI interface).

    If a weight for ``name`` already exists on the grid, it is updated to
    ``weight`` rather than inserting a second row for the same grid/model pair.

    Args:
        name: Prediction model name
        weight: Weight value to assign
        grid_uuid: Grid UUID (if None, uses first available grid)
        engine: Optional database engine (uses singleton if not provided)
    """
    if engine is None:
        engine = get_db_engine()

    with Session(engine) as sess:
        if grid_uuid is None:
            # No ORDER BY here, so which grid counts as "first" is up to the database.
            grid = sess.exec(select(Grid)).first()
            if grid is None:
                logger.error("No grids found in database")
                return
            grid_uuid = grid.uuid

        existing = sess.exec(
            select(QualityPredictionModelWeight).where(
                QualityPredictionModelWeight.grid_uuid == grid_uuid,
                QualityPredictionModelWeight.prediction_model_name == name,
            )
        ).first()

        if existing is None:
            sess.add(QualityPredictionModelWeight(grid_uuid=grid_uuid, prediction_model_name=name, weight=weight))
            sess.commit()
            logger.info(f"Initialised weight {weight} for model '{name}' on grid {grid_uuid}")
        else:
            # Re-running the CLI for the same grid/model should not create a duplicate row.
            existing.weight = weight
            sess.add(existing)
            sess.commit()
            logger.info(f"Updated weight to {weight} for model '{name}' on grid {grid_uuid}")
def run() -> None:
    """Command-line entry point: set one prediction model weight on a grid."""

    # Typer turns every parameter of the wrapped function into a CLI option and
    # cannot map the SQLAlchemy ``Engine`` type, so only expose the CLI-friendly
    # parameters and let the engine default to the singleton.
    def _cli(name: str, weight: float, grid_uuid: str | None = None) -> None:
        """
        Initialise a single prediction model weight for a grid.

        Args:
            name: Prediction model name
            weight: Weight value to assign
            grid_uuid: Grid UUID (if omitted, uses first available grid)
        """
        initialise_prediction_model_for_grid(name, weight, grid_uuid)

    typer.run(_cli)