focal.prediction_testing module
Prediction pipeline for testing the CNN model or the MLP model.
This module provides classes that gather test data and evaluate it with either the CNN classification model or the MLP tension regression model.
- class focal.prediction_testing.TensionPredictor(model_path: str, image_folder: str, csv_path: str, angle_threshold: float, diameter_threshold: float, tension_scaler_path: str | None = None)
Bases: object
Predicts tension values using a trained MLP model and preprocessed images and features.
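The optional tension_scaler_path suggests that tension targets were scaled during training, so raw predictions would need mapping back to tension units. The sketch below assumes (this is not confirmed by the module) a standard z-score scaler described by a single mean and scale; the helper name unscale_tension is hypothetical.

```python
from typing import List, Optional, Sequence


def unscale_tension(
    scaled: Sequence[float],
    mean: Optional[float],
    scale: Optional[float],
) -> List[float]:
    """Hypothetical helper: map standardized predictions back to tension units.

    Assumes a z-score scaler, i.e. scaled = (tension - mean) / scale.
    """
    if mean is None or scale is None:
        # No scaler supplied: treat predictions as already being in tension units.
        return list(scaled)
    # Invert the z-score transform element by element.
    return [value * scale + mean for value in scaled]


print(unscale_tension([-1.0, 0.0, 2.0], mean=150.0, scale=10.0))  # [140.0, 150.0, 170.0]
```

If the real scaler is a different type (for example min-max), only the inverse formula changes.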
- load_and_preprocess_image(file_path: str, img_folder: str) → Tensor
Load and preprocess an image from a file path.
- Parameters:
file_path (str) – Path to image file.
img_folder (str) – Path to image folder.
- Returns:
Preprocessed image tensor.
- Return type:
tf.Tensor
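A minimal sketch of what this preprocessing could involve, assuming file_path is resolved relative to img_folder and that preprocessing scales 8-bit pixels to [0, 1] and adds a leading batch axis. Decoding and resizing, which the real method presumably performs with TensorFlow, are omitted, and NumPy stands in for tf.Tensor. Both helper names are hypothetical.

```python
import os

import numpy as np


def resolve_image_path(file_path: str, img_folder: str) -> str:
    """Hypothetical: treat file_path as relative to img_folder."""
    return os.path.join(img_folder, file_path)


def preprocess_pixels(pixels: np.ndarray) -> np.ndarray:
    """Hypothetical: scale uint8 pixels to [0, 1] and add a batch axis."""
    scaled = pixels.astype(np.float32) / 255.0
    # Models usually expect a batch dimension, even for a single image.
    return np.expand_dims(scaled, axis=0)


image = np.array([[0, 255], [51, 102]], dtype=np.uint8)
batch = preprocess_pixels(image)
print(batch.shape)  # (1, 2, 2)
```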
- plot_metric(title: str, X: list[float], y: list[float], x_label: str, y_label: str, x_legend: str, y_legend: str) → None
Plot a metric for model evaluation.
- Parameters:
title (str) – Title of plot.
X (list[float]) – List of x values.
y (list[float]) – List of y values.
x_label (str) – Label for x axis.
y_label (str) – Label for y axis.
x_legend (str) – Legend for x axis.
y_legend (str) – Legend for y axis.
- class focal.prediction_testing.TestPredictions(model_path: str, csv_path: str, scalar_path: str, img_folder: str, angle_threshold: float, diameter_threshold: float, encoder_path: str | None = None, image_only: bool = False, backbone: str = 'efficientnet', classification_type: str = 'binary', threshold: float | None = 0.5)
Bases: DataCollector
Tests model performance on unseen data using metrics such as accuracy, precision, recall, and a confusion matrix.
Supports both image+feature and image-only CNNs.
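The classification_type and threshold parameters suggest how raw model outputs become class labels. A minimal sketch, assuming a binary model emits one sigmoid probability per sample (compared against threshold) and any other model emits one softmax score per class (reduced with arg-max). The helper name and the "multiclass" value are hypothetical.

```python
from typing import List, Optional, Sequence


def probabilities_to_labels(
    probs: Sequence[Sequence[float]],
    classification_type: str = "binary",
    threshold: Optional[float] = 0.5,
) -> List[int]:
    """Hypothetical helper: turn raw model outputs into integer class labels."""
    if classification_type == "binary":
        # Assumes one sigmoid output per row, read as P(class 1).
        cutoff = 0.5 if threshold is None else threshold
        return [int(row[0] >= cutoff) for row in probs]
    # Any other type: assume one softmax score per class and take the arg-max.
    return [max(range(len(row)), key=lambda i: row[i]) for row in probs]


print(probabilities_to_labels([[0.2], [0.7], [0.5]]))  # [0, 1, 1]
print(probabilities_to_labels([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], "multiclass"))  # [1, 0]
```

Whether a score exactly at the threshold counts as positive (>= here) is an assumption.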
- display_classification_report(true_labels: ndarray, pred_labels: List[int], classification_path: str | None = None) → None
Display a classification report comparing true labels to predicted labels.
- Parameters:
true_labels (np.ndarray) – Array of true labels.
pred_labels (list[int]) – List of predicted labels.
classification_path (str, optional) – Path at which to save the classification report.
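A classification report typically lists per-class precision, recall, F1 score, and support. The module may build its report with a library routine; the hypothetical helper below simply recomputes those numbers by hand to show what they mean. It accepts plain sequences rather than the np.ndarray used in the real signature.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence, Tuple


def per_class_metrics(
    true_labels: Sequence[Hashable],
    pred_labels: Sequence[Hashable],
) -> Dict[Hashable, Tuple[float, float, float, int]]:
    """Hypothetical helper: {label: (precision, recall, f1, support)}."""
    pairs = list(zip(true_labels, pred_labels))
    support = Counter(true_labels)    # how often each label is the truth
    predicted = Counter(pred_labels)  # how often each label is predicted
    report = {}
    for label in sorted(set(true_labels) | set(pred_labels)):
        tp = sum(1 for t, p in pairs if t == p == label)
        precision = tp / predicted[label] if predicted[label] else 0.0
        recall = tp / support[label] if support[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = (precision, recall, f1, support[label])
    return report


report = per_class_metrics([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
print(report[0])  # (0.5, 0.5, 0.5, 2)
```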
- display_confusion_matrix(true_labels: ndarray, pred_labels: list[int], model_path: str) → None
Display a confusion matrix comparing true labels to predicted labels.
- Parameters:
true_labels (np.ndarray) – Array of true labels.
pred_labels (list[int]) – List of predicted labels.
model_path (str) – Path to trained model.
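The counts behind a confusion matrix can be sketched as follows, with rows indexed by the true label and columns by the predicted label (a common convention; the module's own display may differ). The helper name is hypothetical, and the class order is passed in explicitly so the layout is predictable.

```python
from typing import Hashable, List, Sequence


def confusion_matrix_sketch(
    true_labels: Sequence[Hashable],
    pred_labels: Sequence[Hashable],
    labels: Sequence[Hashable],
) -> List[List[int]]:
    """Hypothetical helper: matrix[i][j] counts true labels[i] predicted as labels[j]."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(true_labels, pred_labels):
        matrix[index[t]][index[p]] += 1
    return matrix


print(confusion_matrix_sketch([0, 0, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1]))  # [[1, 1], [1, 2]]
```

For a binary model the diagonal holds the correct predictions; the off-diagonal cells are the false positives and false negatives.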
- gather_predictions() → Tuple[ndarray, list, list] | Tuple[None, None, None]
Gather predictions over the test data.
- Returns:
(true_labels, pred_labels, predictions), or (None, None, None) if there is no test data.
- Return type:
tuple
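The documented return contract (three parallel collections, or (None, None, None) when there is no data) can be sketched with a hypothetical loop. The predict and to_label callables stand in for the model call and the label conversion; true labels are returned as a list here rather than the np.ndarray in the real signature.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def gather_predictions_sketch(
    samples: Sequence[Tuple[str, int]],
    predict: Callable[[str], List[float]],
    to_label: Callable[[List[float]], int],
) -> Tuple[Optional[list], Optional[list], Optional[list]]:
    """Hypothetical stand-in for gather_predictions(); not the module's code."""
    if not samples:
        # Mirrors the documented (None, None, None) result when there is no data.
        return None, None, None
    true_labels, pred_labels, predictions = [], [], []
    for image_path, label in samples:
        prediction = predict(image_path)  # raw model output for one image
        predictions.append(prediction)
        pred_labels.append(to_label(prediction))
        true_labels.append(label)
    return true_labels, pred_labels, predictions


rows = [("good_01.png", 1), ("bad_01.png", 0)]
fake_predict = lambda path: [0.9] if "good" in path else [0.1]
to_label = lambda pred: int(pred[0] >= 0.5)
print(gather_predictions_sketch(rows, fake_predict, to_label))  # ([1, 0], [1, 0], [[0.9], [0.1]])
```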
- plot_roc(title: str, true_labels: ndarray, pred_probabilites: ndarray) → None
Plot the ROC curve for the predictions.
- Parameters:
title (str) – Title for the plot.
true_labels (np.ndarray) – Array of true labels.
pred_probabilites (np.ndarray) – Array of predicted probabilities.
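An ROC curve traces the true-positive rate against the false-positive rate as the decision threshold sweeps from high to low. The hypothetical helpers below compute the curve's points for binary labels (1 is positive) and the area under it with the trapezoidal rule. The module presumably plots such points; this sketch only computes them.

```python
from typing import List, Sequence, Tuple


def roc_points(true_labels: Sequence[int], scores: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: (fpr, tpr) lists, starting at (0, 0)."""
    positives = sum(1 for y in true_labels if y == 1)
    negatives = len(true_labels) - positives
    if positives == 0 or negatives == 0:
        raise ValueError("an ROC curve needs both positive and negative examples")
    ranked = sorted(zip(scores, true_labels), key=lambda pair: pair[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, label) in enumerate(ranked):
        if label == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only when the next score differs, so ties share one threshold.
        if i == len(ranked) - 1 or ranked[i + 1][0] != score:
            fpr.append(fp / negatives)
            tpr.append(tp / positives)
    return fpr, tpr


def trapezoid_auc(fpr: Sequence[float], tpr: Sequence[float]) -> float:
    """Area under the curve via the trapezoidal rule."""
    return sum((fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2 for i in range(1, len(fpr)))


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(trapezoid_auc(fpr, tpr))  # 0.75
```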
- test_prediction(image_path: str, feature_vector: ndarray | None = None) → ndarray
Generate a prediction for a single image, plus its feature vector when image_only is False.
- Parameters:
image_path (str) – Path to image to predict.
feature_vector (np.ndarray | None) – Numerical features (ignored when image_only is True).
- Returns:
Model prediction.
- Return type:
np.ndarray
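The image_only switch determines whether the model receives one input or two. A minimal sketch of that dispatch, assuming an image+feature model takes its inputs as [image, features] in that order; the helper name and the error raised for a missing feature vector are hypothetical, not documented behavior.

```python
from typing import Any, List, Optional, Sequence


def build_model_inputs(
    image: Any,
    feature_vector: Optional[Sequence[float]],
    image_only: bool,
) -> List[Any]:
    """Hypothetical helper: assemble the inputs passed to the model."""
    if image_only:
        # Image-only CNN: the feature vector is ignored, as documented.
        return [image]
    if feature_vector is None:
        raise ValueError("feature_vector is required when image_only is False")
    # Two-input model: image branch first, numerical features second (assumed order).
    return [image, list(feature_vector)]


print(build_model_inputs("img", [0.3, 1.2], image_only=True))   # ['img']
print(build_model_inputs("img", [0.3, 1.2], image_only=False))  # ['img', [0.3, 1.2]]
```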
- class focal.prediction_testing.TestTensionPredictions(cnn_model_path: str, tension_model_path: str, csv_path: str, scaler_path: str, img_folder: str, tension_threshold: int, image_only: bool)
Bases: BadCleaveTensionClassifier
- gather_predictions(pred_features=None)
Gather predictions for image-only or image+feature classification or regression.
- Parameters:
pred_image_paths (list) – Paths to the images for prediction.
pred_features (np.ndarray, optional) – Feature vectors for each image.
- Returns:
Tuple of (true_labels, pred_labels).