focal.prediction_testing module

Prediction pipeline for testing a CNN model or an MLP model.

This module provides classes for gathering test data and evaluating either the CNN model or the regression model.

class focal.prediction_testing.TensionPredictor(model_path: str, image_folder: str, csv_path: str, angle_threshold: float, diameter_threshold: float, tension_scaler_path: str | None = None)

Bases: object

Predicts tension values using a trained MLP model on preprocessed images and features.

load_and_preprocess_image(file_path: str, img_folder: str) → Tensor

Load and preprocess an image from a file path.

Parameters:
  • file_path (str) – Path to image file.

  • img_folder (str) – Path to image folder.

Returns:

Preprocessed image tensor.

Return type:

tf.Tensor

plot_metric(title: str, X: list[float], y: list[float], x_label: str, y_label: str, x_legend: str, y_legend: str) → None

Plot a metric for model evaluation.

Parameters:
  • title (str) – Title of plot.

  • X (list[float]) – List of x values.

  • y (list[float]) – List of y values.

  • x_label (str) – Label for x axis.

  • y_label (str) – Label for y axis.

  • x_legend (str) – Legend for x axis.

  • y_legend (str) – Legend for y axis.

predict()

Run tension predictions on filtered cleave data.

class focal.prediction_testing.TestPredictions(model_path: str, csv_path: str, scalar_path: str, img_folder: str, angle_threshold: float, diameter_threshold: float, encoder_path: str | None = None, image_only: bool = False, backbone: str = 'efficientnet', classification_type: str = 'binary', threshold: float | None = 0.5)

Bases: DataCollector

Tests model performance on unseen data using metrics such as accuracy, precision, recall, and a confusion matrix.

Supports both image+feature and image-only CNNs.

display_classification_report(true_labels: ndarray, pred_labels: List[int], classification_path: str | None = None) → None

Display classification report comparing true labels to predicted labels.

Parameters:
  • true_labels (np.ndarray) – Array of true labels.

  • pred_labels (list[int]) – List of predicted labels.

  • classification_path (str, optional) – Optional path to save classification report.
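For reference, the per-class figures a classification report typically shows (precision, recall, F1, support) can be derived from the two label arrays as sketched below. This is a plain-Python illustration, not this module's implementation (which may delegate to a library such as scikit-learn); the helper name `classification_summary` is hypothetical.

```python
from collections import Counter


def classification_summary(true_labels, pred_labels):
    """Return {label: (precision, recall, f1, support)} for each class.

    Hypothetical helper: a plain-Python sketch of classification-report
    figures, not the module's own implementation.
    """
    if len(true_labels) != len(pred_labels):
        raise ValueError("true_labels and pred_labels must have the same length")

    support = Counter(true_labels)    # samples that truly belong to each class
    predicted = Counter(pred_labels)  # samples predicted as each class
    hits = Counter(t for t, p in zip(true_labels, pred_labels) if t == p)

    report = {}
    for label in sorted(set(support) | set(predicted)):
        tp = hits[label]
        precision = tp / predicted[label] if predicted[label] else 0.0
        recall = tp / support[label] if support[label] else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        report[label] = (precision, recall, f1, support[label])
    return report


print(classification_summary([0, 0, 1, 1, 1], [0, 1, 1, 1, 0]))
```

With these labels, class 0 scores 0.5 on precision, recall and F1 (support 2), and class 1 scores 2/3 on all three (support 3).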

display_confusion_matrix(true_labels: ndarray, pred_labels: list[int], model_path: str) → None

Display confusion matrix comparing true labels to predicted labels.

Parameters:
  • true_labels (np.ndarray) – Array of true labels.

  • pred_labels (list[int]) – List of predicted labels.

  • model_path (str) – Path to trained model.
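The counts behind a confusion matrix can be sketched in a few lines. The helper below is hypothetical and only illustrates the layout: rows index the true label and columns the predicted label, which is the convention scikit-learn uses. Whether this module's plot uses the same orientation is an assumption.

```python
def confusion_matrix(true_labels, pred_labels, labels=None):
    """Return a square count matrix: row i = true labels[i], column j = predicted labels[j].

    Hypothetical helper for illustration; not the module's implementation.
    """
    if labels is None:
        labels = sorted(set(true_labels) | set(pred_labels))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(true_labels, pred_labels):
        matrix[index[t]][index[p]] += 1  # one count per (true, predicted) pair
    return matrix


print(confusion_matrix([0, 0, 1, 1, 1], [0, 1, 1, 1, 0]))
```

The diagonal holds correct predictions; off-diagonal cells are misclassifications.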

gather_predictions() → Tuple[ndarray, list, list] | Tuple[None, None, None]

Gather predictions for every sample in the test data.

Returns:

(true_labels, pred_labels, predictions), or (None, None, None) if there is no test data.

Return type:

tuple
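The sketch below illustrates the documented return contract and one plausible way the `classification_type` and `threshold` constructor parameters could turn raw model outputs into labels. Everything here is an assumption, not the module's code: `gather_predictions_sketch` and its `samples`/`predict` arguments are hypothetical; binary outputs are assumed to be a single probability compared with `threshold` (falling back to 0.5 when it is None); multiclass outputs are assumed to use argmax. It returns plain lists where the documented method returns an ndarray for `true_labels`.

```python
from typing import Callable, Optional, Sequence, Tuple


def gather_predictions_sketch(
    samples: Sequence[Tuple[str, int]],
    predict: Callable[[str], Sequence[float]],
    classification_type: str = "binary",
    threshold: Optional[float] = 0.5,
):
    """Collect true labels, predicted labels and raw model outputs.

    Hypothetical: `samples` holds (image_path, true_label) pairs and
    `predict` returns the model output for one image.
    """
    if not samples:
        return None, None, None  # mirrors the documented "no data" return

    true_labels, pred_labels, predictions = [], [], []
    for image_path, label in samples:
        output = list(predict(image_path))
        if classification_type == "binary":
            # Assumed rule: single sigmoid output compared with the threshold
            cutoff = 0.5 if threshold is None else threshold
            pred = int(output[0] >= cutoff)
        else:
            # Assumed rule for multiclass: index of the highest score
            pred = max(range(len(output)), key=output.__getitem__)
        true_labels.append(label)
        pred_labels.append(pred)
        predictions.append(output)
    return true_labels, pred_labels, predictions


fake_outputs = {"a.png": [0.9], "b.png": [0.2]}
print(gather_predictions_sketch([("a.png", 1), ("b.png", 0)], fake_outputs.get))
```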

plot_roc(title: str, true_labels: ndarray, pred_probabilites: ndarray) → None

Plot ROC curve for predictions.

Parameters:
  • title (str) – Title for the plot.

  • true_labels (np.ndarray) – Array of true labels.

  • pred_probabilites (np.ndarray) – Array of predicted probabilities.
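An ROC curve is obtained by sweeping the decision threshold over the predicted probabilities and recording the false-positive and true-positive rates at each step. The plain-Python sketch below shows that computation plus the trapezoidal area under the curve; it assumes binary 0/1 labels with 1 as the positive class. In practice `sklearn.metrics.roc_curve` and `sklearn.metrics.auc` do the same job; whether `plot_roc` uses them is not stated here.

```python
def roc_points(true_labels, pred_probabilities):
    """Return (fpr, tpr) lists for a binary ROC curve (labels 0/1, 1 = positive)."""
    # Highest scores first: lowering the threshold admits samples in this order
    pairs = sorted(zip(pred_probabilities, true_labels), key=lambda pair: pair[0], reverse=True)
    positives = sum(1 for _, y in pairs if y == 1)
    negatives = len(pairs) - positives
    if positives == 0 or negatives == 0:
        raise ValueError("ROC needs at least one positive and one negative label")

    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        if label == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last sample sharing this score, so ties move together
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            fpr.append(fp / negatives)
            tpr.append(tp / positives)
    return fpr, tpr


def auc(fpr, tpr):
    """Area under the (fpr, tpr) curve by the trapezoidal rule."""
    return sum((fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2 for i in range(1, len(fpr)))


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(fpr, tpr, auc(fpr, tpr))
```

For this small example the area is 0.75: one of the four positive/negative pairs is ranked in the wrong order.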

test_prediction(image_path: str, feature_vector: ndarray | None = None) → ndarray

Generate a prediction for a single image (and its feature vector, unless image_only is set).

Parameters:
  • image_path (str) – Path to image to predict.

  • feature_vector (np.ndarray | None) – Numerical features (ignored if image_only).

Returns:

Model prediction.

Return type:

np.ndarray

class focal.prediction_testing.TestTensionPredictions(cnn_model_path: str, tension_model_path: str, csv_path: str, scaler_path: str, img_folder: str, tension_threshold: int, image_only: bool)

Bases: BadCleaveTensionClassifier

gather_predictions(pred_features=None)

Gather predictions for image-only or image+feature based classification/regression.

Parameters:
  • pred_image_paths (list) – Paths to the images for prediction.

  • pred_features (np.ndarray, optional) – Feature vectors for each image.

Returns:

Tuple of (true_labels, pred_labels)

predict_tension(image_path: str, params=None)