Source code for focal.model_pipeline

"""Model pipeline module for the FOCAL application.

This module provides classes for building, training, and managing CNN
model for fiber cleave quality classification. It uses binary classification to place
the image as either a good or bad cleave.
"""

import os
import warnings
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Suppress TensorFlow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")

try:
    import tensorflow as tf
    from tensorflow.keras.applications import (
        EfficientNetB0,
        MobileNetV2,
        ResNet50,
    )
    from tensorflow.keras.callbacks import (
        EarlyStopping,
        ModelCheckpoint,
        ReduceLROnPlateau,
        TensorBoard,
    )
    from tensorflow.keras.layers import (
        Activation,
        BatchNormalization,
        Concatenate,
        Conv2D,
        Dense,
        Dropout,
        GlobalAveragePooling2D,
        Input,
        MaxPooling2D,
        RandomBrightness,
        RandomContrast,
        RandomRotation,
        RandomZoom,
    )
    from tensorflow.keras.models import Model, Sequential
    from tensorflow.keras.regularizers import l2
except ImportError as e:
    print(f"Warning: TensorFlow not found: {e}")
    print("Please install tensorflow>=2.19.0")
    tf = None


[docs] class CustomModel: """Class for defining custom models using pre-trained MobileNetV2. This class provides functionality for building, compiling, and training CNN models for fiber cleave classification. """ def __init__( self, train_ds: "tf.data.Dataset", test_ds: "tf.data.Dataset", num_classes: int, classification_type: Optional[str] = "binary", ) -> None: """Initialize the custom model. Args: train_ds: Training dataset test_ds: Test dataset num_classes: Number of output classes classification_type: Type of classification """ if tf is None: raise ImportError("TensorFlow is required for CustomModel") self.train_ds = train_ds self.test_ds = test_ds self.classification_type = classification_type self.num_classes = num_classes def _get_backbone_model( self, backbone: str, image_shape: Tuple[int, int, int] ) -> tf.keras.Model: """Get pretrained backbone model based on specified backbone type. Args: backbone: Type of backbone model ("mobilenet", "resnet", "efficientnet") image_shape: Input image shape (height, width, channels) Returns: tf.keras.Model: Pretrained backbone model Raises: ValueError: If backbone type is not supported """ if backbone == "mobilenet": pre_trained_model = MobileNetV2( input_shape=image_shape, include_top=False, weights="imagenet", name="mobilenet", ) elif backbone == "resnet": pre_trained_model = ResNet50( input_shape=image_shape, include_top=False, weights="imagenet", name="resnet50", ) elif backbone == "efficientnet": pre_trained_model = EfficientNetB0( input_shape=image_shape, include_top=False, weights="imagenet", name="efficientnetb0", ) else: raise ValueError( f"Unsupported backbone: {backbone}. Supported backbones: mobilenet, resnet, efficientnet" ) return pre_trained_model
[docs] def get_data_augmentation( self, rotation: float, brightness: float, height: float, width: float, contrast: float, ): """Get data augmentation parameters for training model. Args: rotation (float): Rotation amount of image. brightness (float): Brightness amount of image. height (float): Height of image. width (float): Widht of image. contrast (float): Contrast amount of image. Returns: tf.keras.models.Model: Data augmentation model. """ data_augmentation = Sequential( [ RandomRotation(factor=rotation), RandomBrightness(factor=brightness), RandomZoom(height_factor=height, width_factor=width), RandomContrast(contrast), ] ) return data_augmentation
def _build_custom_model( self, image_shape: Tuple[int, int, int], num_classes: int = 5 ) -> Model: """Build custom model architecture. Args: image_shape (Tuple[int, int, int]): Shape of image pixel vector. num_classes (int, optional): Num of classifcation classes. Defaults to 5. Returns: Model: Custom model. """ data_augmentation = Sequential( [ RandomRotation(factor=0.02), RandomBrightness(factor=0.02), ] ) image_input = Input(shape=image_shape) x = data_augmentation(image_input) x = Conv2D(16, (5, 5), padding="same")(x) x = BatchNormalization()(x) x = Activation("relu")(x) x = MaxPooling2D(pool_size=(4, 4))(x) x = Dropout(0.25)(x) x = Conv2D(32, (3, 3), padding="same")(x) x = BatchNormalization()(x) x = Activation("relu")(x) x = GlobalAveragePooling2D()(x) x = Dense(16, activation="relu")(x) x = Dropout(0.5)(x) output = Dense(num_classes, activation="softmax")(x) model = Model(inputs=image_input, outputs=output) return model def _build_pretrained_model( self, image_shape: Tuple[int, int, int], param_shape: Tuple[int, ...], dropout1: float, dense1: int, dropout2: float, dense2: int, dropout3: float, brightness: float, contrast: float, height: float, width: float, rotation: float, backbone: Optional[str] = "mobilenet", unfreeze_from: Optional[int] = None, ) -> "tf.keras.Model": """Build a model using pre-trained EfficientNetB0 to supplement small dataset. Args: image_shape: Dimensions of input images (height, width, channels) param_shape: Dimensions of numerical parameters unfreeze_from: Layer index from which to unfreeze weights (None = all frozen) dropout1: Perentage of inputs to zero out dense1: Number of neurons in first fully connected layer dropout2: Percentage of inputs to zero out dense2: Number of neurons in second fully connected layer dropout3: Percentage of final inputs to zero out Returns: tf.keras.Model: Compiled model ready for training """ pre_trained_model = self._get_backbone_model( backbone=backbone, image_shape=image_shape ) pre_trained_model.trainable = unfreeze_from is not None if unfreeze_from is not None: for layer in pre_trained_model.layers[:unfreeze_from]: layer.trainable = False # Data augmentation pipeline data_augmentation = self.get_data_augmentation( rotation=rotation, brightness=brightness, height=height, width=width, contrast=contrast, ) # CNN for images image_input = Input(shape=image_shape) x = data_augmentation(image_input) x = pre_trained_model(x) x = GlobalAveragePooling2D(name="global_avg")(x) x = Dropout(dropout1, name="dropout_1")(x) # Numerical features section params_input = Input(shape=param_shape) y = Dense(dense1, name="dense_1", activation="relu")(params_input) y = Dropout(dropout2, name="dropout_2")( y ) # Added to remove reliance on features y = BatchNormalization()(y) combined = Concatenate()([x, y]) z = Dense(dense2, name="dense_2", activation="relu")(combined) z = BatchNormalization(name="batch_norm")(z) z = Dropout(dropout3, name="dropout_3")(z) if self.classification_type == "binary": activation = "sigmoid" elif self.classification_type == "multiclass": activation = "softmax" z = Dense( self.num_classes, name="output_layer", activation=activation, )(z) model = Model(inputs=[image_input, params_input], outputs=z) model.summary() return model def _build_image_only_model( self, image_shape: Tuple[int, int, int], backbone: Optional[str] = "efficientnet", num_classes: int = 5, rotation: Optional[float] = 0.0, height: Optional[float] = 0.0, width: Optional[float] = 0.0, contrast: Optional[float] = 0.0, brightness: Optional[float] = 0.0, dropout1: Optional[float] = 0.1, dense1: Optional[int] = 32, dropout2: Optional[float] = 0.2, l2_factor: Optional[float] = None, unfreeze_from: Optional[int] = None, ) -> "tf.keras.Model": """Build a model that uses only the image input (no parameter features). Args: image_shape: Dimensions of input images backbone: Type of base model to use dropout1_rate: level of dropout for first layer dense1: Units for first hidden layer dropout2_rate: Level of dropout for second layer height: Height of augmented image. width: Width of augmented image. brightness: Amount of brightness for augmented image. contrast: Contrast amount for augmented image. unfreeze_from: Layer to unfreeze backbone model from. Returns: tf.keras.Model: Image-only classification model """ pre_trained_model = self._get_backbone_model( backbone=backbone, image_shape=image_shape ) pre_trained_model.trainable = unfreeze_from is not None if unfreeze_from is not None: for layer in pre_trained_model.layers[:unfreeze_from]: layer.trainable = False data_augmentation = self.get_data_augmentation( rotation=rotation, brightness=brightness, height=height, width=width, contrast=contrast, ) image_input = Input(shape=image_shape) x = data_augmentation(image_input) x = pre_trained_model(x) x = GlobalAveragePooling2D(name="global_avg")(x) x = Dropout(dropout1, name="dropout_1")(x) if l2_factor: x = Dense( dense1, name="dense1", activation="relu", kernel_regularizer=l2(l2_factor), )(x) else: x = Dense(dense1, name="dense1", activation="relu")(x) x = Dropout(dropout2, name="dropout_2")(x) if self.classification_type == "binary": activation = "sigmoid" elif self.classification_type == "multiclass": activation = "softmax" output = Dense( num_classes, name="output_layer", activation=activation )(x) model = Model(inputs=image_input, outputs=output) model.summary() return model
[docs] def compile_image_only_model( self, image_shape: Tuple[int, int, int], learning_rate: float = 0.001, metrics: Optional[List[str]] = None, backbone: Optional[str] = "mobilenet", num_classes: int = 5, dropout1: Optional[float] = 0.1, dense1: Optional[int] = 32, dropout2: Optional[float] = 0.2, l2_factor: Optional[float] = None, unfreeze_from: Optional[int] = None, ) -> "tf.keras.Model": """Compile an image-only model. Args: image_shape: Dimensions of input images learning_rate: Learning rate for optimization metrics: List of metrics to monitor Returns: tf.keras.Model: Compiled image-only model """ if metrics is None: metrics = [ "accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), ] model = self._build_image_only_model( image_shape, backbone=backbone, dropout1=dropout1, dense1=dense1, dropout2=dropout2, num_classes=num_classes, l2_factor=l2_factor, unfreeze_from=unfreeze_from, ) optimizer = tf.keras.optimizers.AdamW(learning_rate=learning_rate) if self.classification_type == "binary": loss = "binary_crossentropy" elif self.classification_type == "multiclass": loss = "categorical_crossentropy" model.compile(optimizer=optimizer, loss=loss, metrics=metrics) return model
[docs] def compile_custom_model( self, image_shape: Tuple[int, int, int], learning_rate: float = 0.001, metrics: Optional[List[str]] = None, num_classes: Optional[int] = 5, ) -> "tf.keras.Model": """Compile custom model after calling build_custom_model function. Args: image_shape: Dimensions of input images param_shape: Dimensions of numerical parameters learning_rate: Learning rate for optimization metrics: List of metrics to monitor num_class: number of output classes Returns: tf.keras.Model: Compiled model ready for training """ if metrics is None: metrics = ["accuracy"] model = self._build_custom_model(image_shape, num_classes=num_classes) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) if self.classification_type == "binary": loss = "binary_crossentropy" elif self.classification_type == "multiclass": loss = "categorical_crossentropy" model.compile(optimizer=optimizer, loss=loss, metrics=metrics) return model
[docs] def compile_model( self, image_shape: Tuple[int, int, int], param_shape: Tuple[int, ...], dropout1: float, dense1: int, dropout2: float, dense2: int, dropout3: float, brightness: float, height: float, width: float, contrast: float, rotation: float, learning_rate: float = 0.001, metrics: Optional[List[str]] = None, unfreeze_from: Optional[int] = None, backbone: Optional[str] = "mobilenet", ) -> Model: """Compile the custom model head on top of pre-trained backbone. Args: image_shape (Tuple[int, int, int]): Shape of each image. param_shape (Tuple[int, ...]): Shape of feature vector. dropout1 (float): Amount of dropout for image layer. dense1 (int): Amount of neurons for first FC layer. dropout2 (float): Amount of dropout for feature input. dense2 (int): Amount of neruons for second FC layer. dropout3 (float): Amount of dropout for concatenated images+features. brightness (float): Brightness amount for image. height (float): Height of augmented image. width (float): Width of augmented image. contrast (float): Contrast amount of augmented image. rotation (float): Rotation amount of augmented image. learning_rate (float, optional): Size of steps to take during optimization. Defaults to 0.001. metrics (Optional[List[str]], optional): Saved metrics. Defaults to None. unfreeze_from (Optional[int], optional): Layer to start training pre-trained backbone. Defaults to None. backbone (Optional[str], optional): Name of pre-trained backbone. Defaults to "mobilenet". Returns: Model: Compiled custom head on pre-trained model backbone. """ if metrics is None: metrics = [ "accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), ] model = self._build_pretrained_model( image_shape, param_shape, unfreeze_from=unfreeze_from, backbone=backbone, dense1=dense1, dense2=dense2, dropout1=dropout1, dropout2=dropout2, dropout3=dropout3, brightness=brightness, height=height, width=width, rotation=rotation, contrast=contrast, ) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) if self.classification_type == "binary": loss = "binary_crossentropy" elif self.classification_type == "multiclass": loss = "categorical_crossentropy" model.compile(optimizer=optimizer, loss=loss, metrics=metrics) return model
[docs] def create_checkpoints( self, checkpoint_filepath: str = "./checkpoints/model.keras", monitor: str = "val_accuracy", mode: str = "max", save_best_only: bool = True, ) -> ModelCheckpoint: """Create model checkpoints to avoid losing data while training. Args: checkpoint_filepath: Path to save model checkpoints monitor: Metric to monitor during training mode: Method to determine stopping point of metric (max, min, auto) save_best_only: Whether to save only the best model Returns: tf.keras.callbacks.ModelCheckpoint: Checkpoint callback """ # Ensure directory exists os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True) model_checkpoint_callback = ModelCheckpoint( filepath=checkpoint_filepath, monitor=monitor, mode=mode, save_best_only=save_best_only, verbose=1, ) return model_checkpoint_callback
[docs] def reduce_on_plateau( self, patience: int = 3, mode: str = "auto", factor: float = 2.0, monitor: str = "val_accuracy", ) -> ReduceLROnPlateau: """Create reduced learning rate callback if monitored value stops improving. Args: patience: Number of epochs before reducing learning rate mode: Method to monitor (min, max, auto) factor: Factor by which to reduce learning rate monitor: Value to monitor before reducing learning rate Returns: tf.keras.callbacks.ReduceLROnPlateau: Reduce learning rate callback """ reduce_lr = ReduceLROnPlateau( monitor=monitor, patience=patience, mode=mode, factor=factor, min_lr=1e-7, ) return reduce_lr
[docs] def create_early_stopping( self, patience: int = 3, mode: str = "max", monitor: str = "val_accuracy", ) -> EarlyStopping: """Create early stopping callback to monitor training success and prevent overfitting. Args: patience: Number of epochs to wait before stopping when monitor plateaus mode: Method to track monitor (max, min, auto) monitor: Metric to monitor during training Returns: tf.keras.callbacks.EarlyStopping: Early stopping callback """ es_callback = EarlyStopping( monitor=monitor, patience=patience, mode=mode, restore_best_weights=True, verbose=1, ) return es_callback
[docs] def create_tensorboard_callback( self, log_dir: str = "./logs", histogram_freq: int = 1 ) -> TensorBoard: """Create TensorBoard callback for monitoring training. Args: log_dir: Directory for TensorBoard logs histogram_freq: Frequency for computing weight histograms Returns: tf.keras.callbacks.TensorBoard: TensorBoard callback """ return TensorBoard(log_dir=log_dir, histogram_freq=histogram_freq)
[docs] def train_model( self, class_weights, model: tf.keras.Model, epochs: int = 5, initial_epoch: int = 0, callbacks: Optional[List] = None, history_file: Optional[str] = None, save_model_file: Optional[str] = None, ) -> tf.keras.callbacks.History: """Train model with possible callbacks to prevent overfitting. Args: model: Model to be trained checkpoints: Checkpoints to save model epochs: Number of training epochs initial_epoch: Starting epoch number early_stopping: Early stopping callback reduce_lr: Reduce learning rate callback tensorboard: TensorBoard callback history_file: File to save training history model_file: File to save trained model Returns: tf.keras.callbacks.History: Training history """ if callbacks: history = model.fit( self.train_ds, epochs=epochs, initial_epoch=initial_epoch, validation_data=self.test_ds, callbacks=callbacks, class_weight=class_weights, ) else: print("Training without callbacks") history = model.fit( self.train_ds, epochs=epochs, initial_epoch=initial_epoch, validation_data=self.test_ds, class_weight=class_weights, ) # Save training history if history_file: os.makedirs(os.path.dirname(history_file), exist_ok=True) df = pd.DataFrame(history.history) df.to_csv(history_file, index=False) print(f"Training history saved to: {history_file}") else: print("Training history not saved") # Save trained model if save_model_file: os.makedirs(os.path.dirname(save_model_file), exist_ok=True) model.save(save_model_file) print(f"Model saved to: {save_model_file}") else: print("Model not saved") return history
[docs] @staticmethod def train_kfold( datasets: List[Tuple], image_shape: Tuple[int, int, int], param_shape: Tuple[int, ...], learning_rate: float, num_classes: int, dropout1: float, dense1: int, dropout2: float, dense2: int, dropout3: float, brightness: float, height: float, width: float, contrast: float, rotation: float, metrics: List[str] = None, epochs: int = 5, initial_epoch: int = 0, history_file: Optional[str] = None, save_model_file: Optional[str] = None, callbacks: Optional[List] = None, ) -> Tuple[List[tf.keras.Model], List[tf.keras.callbacks.History]]: """Train model using k-fold cross validation. Args: datasets: List of (train_ds, test_ds) tuples for each fold image_shape: Dimensions of input images param_shape: Dimensions of numerical parameters dropout1: Amount of dropout for image layer. dense1: Amount of neurons for first FC layer. dropout2: Amount of dropout for feature input. dense2: Amount of neruons for second FC layer. dropout3: Amount of dropout for concatenated images+features. brightness: Brightness amount for image. height: Height of augmented image. width: Width of augmented image. contrast: Contrast amount of augmented image. rotation: Rotation amount of augmented image. learning_rate: Learning rate for optimization. metrics: List of metrics to monitor. epochs: Number of training epochs. initial_epoch: Starting epoch number. history_file: Base filename for saving training history. model_file: Base filename for saving models. Returns: Tuple of (list of trained models, list of training histories) """ if metrics is None: metrics = ["accuracy", "precision", "recall"] kfold_histories = [] k_models = [] train_datasets = [i[0] for i in datasets] test_datasets = [i[1] for i in datasets] for fold, (train_ds, test_ds) in enumerate( zip(train_datasets, test_datasets) ): print(f"\n=== Training fold {fold + 1} ===") custom_model = CustomModel(train_ds, test_ds, num_classes) model = custom_model.compile_model( image_shape=image_shape, param_shape=param_shape, learning_rate=learning_rate, dropout1=dropout1, dense1=dense1, dropout2=dropout2, dense2=dense2, dropout3=dropout3, brightness=brightness, height=height, width=width, contrast=contrast, rotation=rotation, metrics=metrics, ) es = EarlyStopping( monitor="val_accuracy", patience=4, restore_best_weights=True, verbose=1, ) callbacks = [es] if callbacks: history = model.fit( train_ds, epochs=epochs, initial_epoch=initial_epoch, validation_data=test_ds, callbacks=callbacks, ) else: print("Training without callbacks") history = model.fit( train_ds, epochs=epochs, initial_epoch=initial_epoch, validation_data=test_ds, ) kfold_histories.append(history) k_models.append(model) # Save fold-specific history and model if history_file: os.makedirs(os.path.dirname(history_file), exist_ok=True) df = pd.DataFrame(history.history) df.to_csv(f"{history_file}_fold{fold+1}.csv", index=False) print(f"Fold {fold+1} history saved") else: print("History not saved") if save_model_file: os.makedirs(os.path.dirname(save_model_file), exist_ok=True) save_model_file = save_model_file.strip(".keras") model.save(f"{save_model_file}_fold{fold+1}.keras") print(f"Fold {fold+1} model saved") else: print("Model not saved") return k_models, kfold_histories
[docs] @staticmethod def get_averages_from_kfold( kfold_histories: List[tf.keras.callbacks.History], ) -> None: """Calculate and display average metrics from k-fold cross validation. Args: kfold_histories: List of training histories from k-fold training """ accuracy = [] precision = [] recall = [] for history in kfold_histories: accuracy.append(max(history.history["val_accuracy"])) precision.append(max(history.history["val_precision"])) recall.append(max(history.history["val_recall"])) avg_accuracy = np.mean(accuracy) avg_precision = np.mean(precision) avg_recall = np.mean(recall) print( f"Average Accuracy: {avg_accuracy:.4f}, Std Dev: {np.std(accuracy):.4f}" ) print( f"Average Precision: {avg_precision:.4f}, Std Dev: {np.std(precision):.4f}" ) print( f"Average Recall: {avg_recall:.4f}, Std Dev: {np.std(recall):.4f}" )
[docs] def plot_metric( self, title: str, metric_1: List[float], metric_2: List[float], metric_1_label: str, metric_2_label: str, x_label: str, y_label: str, model_path: str, ) -> None: """Plot training metrics for visualization. Args: title: Title for the plot metric_1: First metric values to plot metric_2: Second metric values to plot metric_1_label: Label for first metric metric_2_label: Label for second metric x_label: Label for x-axis y_label: Label for y-axis """ plt.figure(figsize=(10, 6)) plt.title(title) plt.plot(metric_1, label=metric_1_label) plt.plot(metric_2, label=metric_2_label) plt.xlabel(x_label) plt.ylabel(y_label) plt.legend(loc="lower right") plt.grid(True, alpha=0.3) plt.tight_layout() model_dir = os.path.dirname(model_path) basename = os.path.basename(model_path) stem, _ = os.path.splitext(basename) save_plot = os.path.join(model_dir, f"{stem}_{title}.png") plt.savefig(save_plot) plt.show()