Source code for focal.commands.test_commands

"""This module defines the logic for testing each model."""

from focal.mlflow_utils import (
    log_classifier_test_results,
    log_regressor_test_results,
    log_rl_test,
)
from focal.prediction_testing import (
    TensionPredictor,
    TestPredictions,
    TestTensionPredictions,
)
from focal.rl_pipeline import TestAgent
from focal.xgboost_pipeline import XGBoostPredictor

from .base_command import BaseCommand


class TestCNN(BaseCommand):
    """Test CNN model performance.

    ``config.cnn_mode`` selects the tester:

    * ``"bad_good"`` -- ``TestPredictions`` on images plus scaled features.
    * ``"tension"`` -- ``TestTensionPredictions``, which combines the CNN model
      with a separate tension model.

    The tester's confusion matrix, classification report and ROC curve are
    then passed to ``log_classifier_test_results``.
    """

    def _build_tester(self, config):
        """Construct the tester that matches ``config.cnn_mode``.

        Raises:
            ValueError: If ``config.cnn_mode`` is not ``"bad_good"`` or
                ``"tension"``. Without this check, an unknown mode left the
                tester unbound and failed later with an ``UnboundLocalError``.
        """
        if config.cnn_mode == "bad_good":
            return TestPredictions(
                model_path=config.model_path,
                csv_path=config.csv_path,
                scalar_path=config.feature_scaler_path,
                img_folder=config.img_folder,
                image_only=False,
                backbone=config.backbone,
                angle_threshold=config.angle_threshold,
                diameter_threshold=config.diameter_threshold,
                threshold=config.classification_threshold,
            )
        if config.cnn_mode == "tension":
            return TestTensionPredictions(
                cnn_model_path=config.model_path,
                tension_model_path=config.tension_model_path,
                csv_path=config.csv_path,
                scaler_path=config.feature_scaler_path,
                img_folder=config.img_folder,
                tension_threshold=config.tension_threshold,
                image_only=False,
            )
        raise ValueError(
            f"Unsupported cnn_mode {config.cnn_mode!r}; "
            "expected 'bad_good' or 'tension'"
        )

    def _execute_command(self, config) -> None:
        """Evaluate the configured CNN and log the classification results.

        Args:
            config: Command configuration. It must provide ``cnn_mode`` and
                the paths/thresholds read by the selected tester, plus
                ``classification_path``.

        Raises:
            ValueError: If ``config.cnn_mode`` is not a supported mode.
        """
        tester = self._build_tester(config)
        true_labels, pred_labels, predictions = tester.gather_predictions()

        # gather_predictions signals "nothing evaluated" with None labels.
        if true_labels is None:
            print("No predictions generated - check data paths")
            return

        confusion_path = tester.display_confusion_matrix(
            true_labels, pred_labels, model_path=config.model_path
        )
        tester.display_classification_report(
            true_labels, pred_labels, config.classification_path
        )
        # NOTE: "pred_probabilites" (sic) is the keyword name plot_roc expects.
        roc_path = tester.plot_roc(
            "ROC Curve",
            true_labels=true_labels,
            pred_probabilites=predictions,
        )
        log_classifier_test_results(
            tester=tester,
            run_name="cnn_test_results",
            confusion_matrix_path=confusion_path,
            model_path=config.model_path,
            dataset_path=config.csv_path,
            classification_path=config.classification_path,
            roc_path=roc_path,
            true_labels=true_labels,
            pred_labels=pred_labels,
            predictions=predictions,
        )
class TestMLP(BaseCommand):
    """Evaluate the MLP tension regressor and log its test results."""

    def _execute_command(self, config) -> None:
        """Run ``TensionPredictor`` over the configured dataset and log it.

        Args:
            config: Command configuration. It supplies the model, image folder,
                label-scaler and CSV paths and the angle/diameter thresholds.
        """
        tension_predictor = TensionPredictor(
            model_path=config.model_path,
            image_folder=config.img_folder,
            tension_scaler_path=config.label_scaler_path,
            csv_path=config.csv_path,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        outputs = tension_predictor.predict()
        # NOTE(review): this unpacks (tensions, true_delta, predicted_deltas,
        # predictions, mean). TestXGBoost unpacks its predictor's result in a
        # different order. Confirm against TensionPredictor.predict().
        tensions, true_delta, predicted_deltas, predictions, mean = outputs

        label = "mlp_results"
        log_regressor_test_results(
            model_path=config.model_path,
            run_name=label,
            experiment_name=label,
            dataset_path=config.csv_path,
            tensions=tensions,
            predicted_delta=predicted_deltas,
            predictions=predictions,
            true_delta=true_delta,
            mean=mean,
        )
class TestImageOnly(BaseCommand):
    """Test CNN model performance on images only (no tabular features).

    Builds an image-only ``TestPredictions`` without a feature scaler. The
    tester's confusion matrix, classification report and ROC curve are then
    passed to ``log_classifier_test_results`` with ``image_only=True``.
    """

    def _execute_command(self, config) -> None:
        """Evaluate the image-only CNN and log the classification results.

        Args:
            config: Command configuration. It supplies the model, CSV, image
                folder and encoder paths, the backbone, ``classification_type``,
                the angle/diameter thresholds and ``classification_path``.
        """
        tester = TestPredictions(
            model_path=config.model_path,
            csv_path=config.csv_path,
            img_folder=config.img_folder,
            scalar_path=None,  # image-only: no tabular-feature scaler
            image_only=True,
            backbone=config.backbone,
            encoder_path=config.encoder_path,
            classification_type=config.classification_type,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        # Gather predictions once. The original called this twice back to
        # back, which threw away the first result and repeated the inference.
        true_labels, pred_labels, predictions = tester.gather_predictions()

        # gather_predictions signals "nothing evaluated" with None labels.
        if true_labels is None:
            print("No predictions generated - check data paths")
            return

        confusion_path = tester.display_confusion_matrix(
            true_labels, pred_labels, model_path=config.model_path
        )
        tester.display_classification_report(
            true_labels, pred_labels, config.classification_path
        )
        # NOTE: "pred_probabilites" (sic) is the keyword name plot_roc expects.
        roc_path = tester.plot_roc(
            "ROC Curve",
            true_labels=true_labels,
            pred_probabilites=predictions,
        )
        log_classifier_test_results(
            tester=tester,
            run_name="image_only_test_results",
            confusion_matrix_path=confusion_path,
            model_path=config.model_path,
            dataset_path=config.csv_path,
            classification_path=config.classification_path,
            roc_path=roc_path,
            true_labels=true_labels,
            pred_labels=pred_labels,
            predictions=predictions,
            image_only=True,
        )
class TestXGBoost(BaseCommand):
    """Test XGBoost performance on predicting the change in tension."""

    def _execute_command(self, config):
        """Load the XGBoost predictor, run it, and log the regression results.

        Args:
            config: Command configuration. It supplies the XGBoost, CSV,
                label-scaler and CNN model paths and the angle/diameter
                thresholds.
        """
        predictor = XGBoostPredictor(
            xgb_path=config.xgb_path,
            csv_path=config.csv_path,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
            scaler_path=config.label_scaler_path,
            cnn_model_path=config.model_path,
        )
        # The trained model must be loaded before predict() is called.
        predictor.load()
        outputs = predictor.predict()
        # NOTE(review): this unpacks (tensions, predicted_deltas, predictions,
        # true_delta, mean). TestMLP unpacks TensionPredictor.predict() in a
        # different order. Confirm against XGBoostPredictor.predict().
        tensions, predicted_deltas, predictions, true_delta, mean = outputs

        label = "xgb_results"
        log_regressor_test_results(
            model_path=config.xgb_path,
            run_name=label,
            experiment_name=label,
            dataset_path=config.csv_path,
            tensions=tensions,
            predicted_delta=predicted_deltas,
            predictions=predictions,
            true_delta=true_delta,
            mean=mean,
        )
class TestRL(BaseCommand):
    """Evaluate a trained reinforcement-learning agent and log the run."""

    @staticmethod
    def _make_agent_tester(config):
        """Build a ``TestAgent`` from the agent/environment settings in *config*."""
        return TestAgent(
            csv_path=config.csv_path,
            cnn_path=config.cnn_path,
            img_folder=config.img_folder,
            agent_path=config.agent_path,
            threshold=config.threshold,
            feature_shape=config.feature_shape,
            max_steps=config.max_steps,
            low_range=config.low_range,
            high_range=config.high_range,
            max_delta=config.max_delta,
            max_tension_change=config.max_tension_change,
        )

    def _execute_command(self, config) -> None:
        """Run ``config.episodes`` test episodes and log them via ``log_rl_test``.

        Args:
            config: Command configuration. It supplies the agent/environment
                settings, ``episodes`` and ``run_name``.
        """
        agent_tester = self._make_agent_tester(config)
        episode_info = agent_tester.test_agent(episodes=config.episodes)
        log_rl_test(config, run_name=config.run_name, info=episode_info)