focal.mlp_model module
Main module for defining MLP training logic.
- class focal.mlp_model.BuildMLPModel(cnn_model_path: str, train_ds: Any, test_ds: Any, num_classes: int)
Bases: CustomModel
MLP model for tension prediction using features from a pre-trained CNN.
This class builds regression models that use features extracted from a CNN to predict optimal tension values for fiber cleaving.
- compile_model(dense1: int, dense2: int, dropout1: float, dropout2: float, dropout3: float, param_shape: Tuple[int, ...], learning_rate: float = 0.001, metrics: List[str] | None = None) → Model
Compile MLP model for regression.
- Parameters:
dense1 – Number of neurons in first FC layer.
dense2 – Number of neurons in second FC layer.
dropout1 – Dropout rate applied to the CNN output.
dropout2 – Dropout rate applied to the numerical feature input.
dropout3 – Dropout rate applied to the concatenated image and numerical features.
param_shape – Shape of the numerical parameter input.
learning_rate – Learning rate for the optimizer.
metrics – List of metrics to monitor during training.
- Returns:
Compiled regression model
- Return type:
tf.keras.Model
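The docstring does not spell out the layer order, so the following sketch assumes one plausible layout: CNN features (after dropout1) are concatenated with the flattened numerical parameters (after dropout2), passed through dropout3, then through Dense(dense1), Dense(dense2) and a single-unit regression output. It counts the trainable parameters of that assumed head in plain Python; `dense_params`, `mlp_head_params` and `cnn_feature_dim` are hypothetical names, and the CNN's actual feature size is not documented here.

```python
from typing import Tuple


def dense_params(n_in: int, n_out: int) -> int:
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


def mlp_head_params(cnn_feature_dim: int, param_shape: Tuple[int, ...],
                    dense1: int, dense2: int, n_outputs: int = 1) -> int:
    """Count trainable parameters of a hypothetical MLP regression head.

    Assumed layout: concat(CNN features, flattened numerical params)
    -> Dense(dense1) -> Dense(dense2) -> Dense(n_outputs).
    Dropout layers (dropout1/2/3) add no trainable parameters.
    """
    n_numeric = 1
    for dim in param_shape:  # flatten the numerical parameter input
        n_numeric *= dim
    concat_dim = cnn_feature_dim + n_numeric
    return (dense_params(concat_dim, dense1)
            + dense_params(dense1, dense2)
            + dense_params(dense2, n_outputs))


# Example: 128-dim CNN features, 3 numerical parameters, 64 -> 32 units.
print(mlp_head_params(128, (3,), dense1=64, dense2=32))  # 10561
```

Because the concatenated input feeds the first dense layer, `dense1` dominates the count under this layout; shrinking it is the cheapest way to reduce head size.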
- create_checkpoints(checkpoint_filepath: str = './checkpoints/mlp_model.keras', monitor: str = 'val_mae', mode: str = 'min', save_best_only: bool = True) → ModelCheckpoint
Create a model checkpoint callback for the MLP model.
- Parameters:
checkpoint_filepath – Path to save model checkpoints
monitor – Metric to monitor during training
mode – Direction of improvement for the monitored metric: 'min' treats a decrease as an improvement (use for error metrics such as val_mae), 'max' an increase
save_best_only – Whether to save only the best model
- Returns:
Checkpoint callback
- Return type:
tf.keras.callbacks.ModelCheckpoint
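To make the interaction of `monitor`, `mode` and `save_best_only` concrete, here is a simplified simulation in plain Python (it does not call Keras). It assumes a checkpoint is written only when the monitored value strictly improves on the best value seen so far; `epochs_saved` is a hypothetical helper.

```python
import math
from typing import List


def epochs_saved(val_metric: List[float], mode: str = "min") -> List[int]:
    """Return the 0-based epochs at which a best-only checkpoint is written.

    Simplified model of save_best_only=True: save only on strict
    improvement over the best monitored value seen so far.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(val_metric):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved.append(epoch)
    return saved


val_mae = [0.30, 0.25, 0.27, 0.22, 0.22, 0.24]
print(epochs_saved(val_mae))  # [0, 1, 3]
```

With the default fixed `checkpoint_filepath`, each save overwrites the same file, so after training it holds the model from the best epoch (epoch 3 here).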
- create_early_stopping(patience: int = 3, mode: str = 'min', monitor: str = 'val_mae') → EarlyStopping
Create an early stopping callback for the regression model.
- Parameters:
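patience – Number of consecutive epochs without improvement in the monitored metric before training stops
mode – Direction of improvement for the monitored metric ('min' for error metrics such as val_mae)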
monitor – Metric to monitor during training
- Returns:
Early stopping callback
- Return type:
tf.keras.callbacks.EarlyStopping
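The effect of `patience` can be sketched with a plain-Python simulation (not a Keras call), assuming no minimum improvement threshold: a counter of epochs without strict improvement resets on each improvement, and training stops once it reaches `patience`. `stop_epoch` is a hypothetical helper.

```python
import math
from typing import List, Optional


def stop_epoch(val_metric: List[float], patience: int = 3,
               mode: str = "min") -> Optional[int]:
    """Return the 0-based epoch after which training stops, or None.

    Simplified early stopping: `wait` counts epochs without strict
    improvement and is reset whenever the monitored value improves.
    """
    best = math.inf if mode == "min" else -math.inf
    wait = 0
    for epoch, value in enumerate(val_metric):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_mae = [0.30, 0.25, 0.26, 0.27, 0.28, 0.20]
print(stop_epoch(val_mae, patience=3))  # 4
```

Here a patience of 3 stops training before the epoch-5 improvement is ever seen, while a patience of 4 would reach it. Keras `EarlyStopping` leaves the weights from the last epoch in place unless `restore_best_weights` is enabled, which this helper's signature does not expose; pairing it with the checkpoint callback keeps the best model on disk.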
- static get_averages_from_kfold(kfold_histories: List[History], scaler: any) → None
Calculate and display average metrics from k-fold cross-validation for the MLP model.
- Parameters:
kfold_histories – List of training histories from k-fold training
scaler – Scaler fitted on the tension values, used to convert metrics back to the original tension units
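A minimal sketch of averaging fold results and denormalizing them, assuming the tension values were min-max scaled (the docstring only says that a scaler is used) and that the best, lowest `val_mae` of each fold is the figure of interest. `MinMaxTensionScaler` and `summarize_kfold` are hypothetical; the histories are plain dicts shaped like Keras `History.history`.

```python
from dataclasses import dataclass
from statistics import mean, stdev
from typing import Dict, List


@dataclass
class MinMaxTensionScaler:
    """Hypothetical stand-in for a min-max scaler fitted on tension values."""
    data_min: float
    data_max: float

    @property
    def data_range(self) -> float:
        return self.data_max - self.data_min


def summarize_kfold(histories: List[Dict[str, List[float]]],
                    scaler: MinMaxTensionScaler,
                    metric: str = "val_mae") -> Dict[str, float]:
    """Average each fold's best (lowest) metric, expressed in tension units.

    An MAE is a difference of two scaled values, so undoing min-max
    scaling only multiplies by the data range; the offset cancels out.
    """
    best_per_fold = [min(h[metric]) * scaler.data_range for h in histories]
    return {
        "mean": mean(best_per_fold),
        "std": stdev(best_per_fold) if len(best_per_fold) > 1 else 0.0,
    }


folds = [{"val_mae": [0.10, 0.08]}, {"val_mae": [0.12, 0.09]},
         {"val_mae": [0.07, 0.06]}]
print(summarize_kfold(folds, MinMaxTensionScaler(100.0, 300.0)))
```

Applying a full inverse transform to an MAE would wrongly add the minimum tension as an offset, which is why only the scale factor is used.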
- static train_kfold_mlp(datasets: List[Tuple], cnn_model_path: str, param_shape: Tuple[int, ...], learning_rate: float, num_classes: int, dense1: int, dense2: int, dropout1: float, dropout2: float, dropout3: float, checkpoints: ModelCheckpoint | None = None, epochs: int = 5, initial_epoch: int = 0, early_stopping: EarlyStopping | None = None, history_file: str | None = None, save_model_file: str | None = None) → Tuple[List[Model], List[History]]
Train the MLP model using k-fold cross-validation.
- Parameters:
datasets – List of (train_ds, test_ds) tuples for each fold
cnn_model_path – Path to the pre-trained CNN model
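param_shape – Shape of the numerical parameter input
learning_rate – Learning rate for the optimizer
num_classes – Number of classes, matching the num_classes argument of BuildMLPModel
dense1 – Number of neurons in the first FC layer
dense2 – Number of neurons in the second FC layer
dropout1 – Dropout rate applied to the CNN output
dropout2 – Dropout rate applied to the numerical feature input
dropout3 – Dropout rate applied to the concatenated image and numerical features
checkpoints – Optional checkpoint callback
epochs – Number of training epochs
initial_epoch – Starting epoch number
early_stopping – Optional early stopping callback
history_file – Base filename for saving training histories
save_model_file – Base filename for saving trained models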
- Returns:
Tuple of (list of trained models, list of training histories)
- Return type:
Tuple[List[tf.keras.Model], List[tf.keras.callbacks.History]]
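`datasets` is a list of `(train_ds, test_ds)` tuples, one per fold. The sketch below shows one way such a list could be produced, using plain Python lists of sample indices in place of `tf.data.Dataset` objects; the project's actual split logic (shuffling, stratification) is not documented here, and `kfold_splits` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def kfold_splits(samples: Sequence[T], k: int) -> List[Tuple[List[T], List[T]]]:
    """Split samples into k (train, test) pairs; each sample is tested once.

    Folds are contiguous and unshuffled; their sizes differ by at most one.
    """
    if not 2 <= k <= len(samples):
        raise ValueError("k must be between 2 and the number of samples")
    n = len(samples)
    splits = []
    start = 0
    for i in range(k):
        size = n // k + (1 if i < n % k else 0)  # spread the remainder
        stop = start + size
        test = list(samples[start:stop])
        train = list(samples[:start]) + list(samples[stop:])
        splits.append((train, test))
        start = stop
    return splits


for train, test in kfold_splits(list(range(7)), k=3):
    print(train, test)
```

In practice the samples would usually be shuffled once before splitting, and each index list would then be turned into a dataset for that fold.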