import random
from typing import Annotated
import typer
from sqlalchemy.engine import Engine
from sqlmodel import Session, select
from smartem_backend.model.database import FoilHole, Grid, GridSquare, QualityPrediction, QualityPredictionModel
from smartem_backend.utils import get_db_engine, logger
DEFAULT_PREDICTION_RANGE = (0.0, 1.0)
def generate_random_predictions(
    model_name: str,
    grid_uuid: str | None = None,
    random_range: tuple[float, float] = (0, 1),
    level: Annotated[
        str, typer.Option(help="Magnification level at which to generate predictions. Options are 'hole' or 'square'")
    ] = "hole",
    engine: Engine = None,
) -> None:
    if level not in ("hole", "square"):
        raise ValueError(f"Level must be set to either 'hole' or 'square' not {level}")
    if engine is None:
        engine = get_db_engine()
    with Session(engine) as sess:
        if grid_uuid is None:
            grid = sess.exec(select(Grid)).first()
            grid_uuid = grid.uuid
        if level == "hole":
            holes = sess.exec(
                select(FoilHole, GridSquare)
                .where(FoilHole.gridsquare_uuid == GridSquare.uuid)
                .where(GridSquare.grid_uuid == grid_uuid)
            ).all()
            preds = [
                QualityPrediction(
                    value=random.uniform(random_range[0], random_range[1]),
                    prediction_model_name=model_name,
                    foilhole_uuid=h[0].uuid,
                )
                for h in holes
            ]
            sess.add_all(preds)
            sess.commit()
        else:
            squares = sess.exec(select(GridSquare).where(GridSquare.grid_uuid == grid_uuid)).all()
            preds = [
                QualityPrediction(
                    value=random.uniform(random_range[0], random_range[1]),
                    prediction_model_name=model_name,
                    gridsquare_uuid=s.uuid,
                )
                for s in squares
            ]
            sess.add_all(preds)
            sess.commit()
    return None
[docs]
def generate_predictions_for_gridsquare(
    gridsquare_uuid: str, grid_uuid: str | None = None, engine: Engine = None
) -> None:
    """
    Generate random predictions for a single gridsquare using all available models.
    Args:
        gridsquare_uuid: UUID of the gridsquare to generate predictions for
        grid_uuid: UUID of the parent grid (optional, will be looked up if not provided)
        engine: Optional database engine (uses singleton if not provided)
    """
    if engine is None:
        engine = get_db_engine()
    with Session(engine) as sess:
        # Get all available prediction models
        models = sess.exec(select(QualityPredictionModel)).all()
        if not models:
            logger.warning(f"No prediction models found to generate predictions for gridsquare {gridsquare_uuid}")
            return
        # If grid_uuid not provided, look it up
        if grid_uuid is None:
            gridsquare = sess.get(GridSquare, gridsquare_uuid)
            if gridsquare is None:
                logger.error(f"GridSquare {gridsquare_uuid} not found in database")
                return
            grid_uuid = gridsquare.grid_uuid
        # Generate predictions for each model
        predictions = []
        for model in models:
            # Check if prediction already exists for this gridsquare-model combination
            existing_prediction = sess.exec(
                select(QualityPrediction).where(
                    QualityPrediction.gridsquare_uuid == gridsquare_uuid,
                    QualityPrediction.prediction_model_name == model.name,
                )
            ).first()
            if existing_prediction is None:
                prediction = QualityPrediction(
                    value=random.uniform(DEFAULT_PREDICTION_RANGE[0], DEFAULT_PREDICTION_RANGE[1]),
                    prediction_model_name=model.name,
                    gridsquare_uuid=gridsquare_uuid,
                )
                predictions.append(prediction)
                logger.info(
                    f"Generated prediction {prediction.value:.3f} for model '{model.name}' "
                    f"on gridsquare {gridsquare_uuid}"
                )
            else:
                logger.debug(f"Prediction already exists for model '{model.name}' on gridsquare {gridsquare_uuid}")
        if predictions:
            sess.add_all(predictions)
            sess.commit()
            logger.info(f"Generated {len(predictions)} predictions for gridsquare {gridsquare_uuid}") 
[docs]
def generate_predictions_for_foilhole(
    foilhole_uuid: str, gridsquare_uuid: str | None = None, engine: Engine = None
) -> None:
    """
    Generate random predictions for a single foilhole using all available models.
    Args:
        foilhole_uuid: UUID of the foilhole to generate predictions for
        gridsquare_uuid: UUID of the parent gridsquare (optional, for validation if provided)
        engine: Optional database engine (uses singleton if not provided)
    """
    if engine is None:
        engine = get_db_engine()
    with Session(engine) as sess:
        # Get all available prediction models
        models = sess.exec(select(QualityPredictionModel)).all()
        if not models:
            logger.warning(f"No prediction models found to generate predictions for foilhole {foilhole_uuid}")
            return
        # Optional validation: if gridsquare_uuid provided, verify the foilhole belongs to it
        if gridsquare_uuid is not None:
            foilhole = sess.get(FoilHole, foilhole_uuid)
            if foilhole is None:
                logger.error(f"FoilHole {foilhole_uuid} not found in database")
                return
            if foilhole.gridsquare_uuid != gridsquare_uuid:
                logger.error(
                    f"FoilHole {foilhole_uuid} belongs to gridsquare {foilhole.gridsquare_uuid}, not {gridsquare_uuid}"
                )
                return
        # Generate predictions for each model
        predictions = []
        for model in models:
            # Check if prediction already exists for this foilhole-model combination
            existing_prediction = sess.exec(
                select(QualityPrediction).where(
                    QualityPrediction.foilhole_uuid == foilhole_uuid,
                    QualityPrediction.prediction_model_name == model.name,
                )
            ).first()
            if existing_prediction is None:
                prediction = QualityPrediction(
                    value=random.uniform(DEFAULT_PREDICTION_RANGE[0], DEFAULT_PREDICTION_RANGE[1]),
                    prediction_model_name=model.name,
                    foilhole_uuid=foilhole_uuid,
                )
                predictions.append(prediction)
                logger.info(
                    f"Generated prediction {prediction.value:.3f} for model '{model.name}' on foilhole {foilhole_uuid}"
                )
            else:
                logger.debug(f"Prediction already exists for model '{model.name}' on foilhole {foilhole_uuid}")
        if predictions:
            sess.add_all(predictions)
            sess.commit()
            logger.info(f"Generated {len(predictions)} predictions for foilhole {foilhole_uuid}") 
def run() -> None:
    typer.run(generate_random_predictions)
    return None