"""Module for defining training logic to call from main entry point"""
import traceback
import tensorflow as tf
from focal.data_processing import (
BadCleaveTensionClassifier,
DataCollector,
MLPDataCollector,
)
from focal.mlflow_utils import (
log_cnn_training_run,
log_image_training_run,
log_mlp_training_run,
log_xgb_training_run,
)
from focal.mlp_model import BuildMLPModel
from focal.model_pipeline import CustomModel
from focal.rl_pipeline import TrainAgent
from focal.xgboost_pipeline import XGBoostModel
from .base_command import BaseCommand
from .utils import _setup_callbacks
[docs]
class TrainCNN(BaseCommand):
    """Command that trains a CNN for fiber cleave classification.

    ``config.cnn_mode`` selects the data source: ``"bad_good"`` uses
    ``DataCollector`` and ``"tension"`` uses ``BadCleaveTensionClassifier``.
    """

    def _execute_command(self, config) -> None:
        """Build datasets, train (or resume) the CNN, log the run, plot curves.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.

        Raises:
            ImportError: If the module-level ``tf`` is ``None``.
            ValueError: If ``config.cnn_mode`` is not ``"bad_good"`` or
                ``"tension"``.
        """
        if tf is None:
            raise ImportError("TensorFlow is required for CNN training")

        # Choose the data collector that matches the requested CNN mode.
        if config.cnn_mode == "bad_good":
            collector = DataCollector(
                config.csv_path,
                config.img_folder,
                classification_type=config.classification_type,
                backbone=config.backbone,
                angle_threshold=config.angle_threshold,
                diameter_threshold=config.diameter_threshold,
            )
        elif config.cnn_mode == "tension":
            collector = BadCleaveTensionClassifier(
                csv_path=config.csv_path,
                img_folder=config.img_folder,
                backbone=config.backbone,
                tension_threshold=config.tension_threshold,
            )
        else:
            raise ValueError(f"Unsupported cnn mode: {config.cnn_mode}")

        images, features, labels = collector.extract_data()
        train_ds, test_ds, class_weights = collector.create_datasets(
            images,
            features,
            labels,
            config.test_size,
            config.buffer_size,
            config.batch_size,
            feature_scaler_path=config.feature_scaler_path,
            train_p=config.train_p,
            test_p=config.test_p,
        )

        trainer = CustomModel(
            train_ds,
            test_ds,
            classification_type=config.classification_type,
            num_classes=config.num_classes,
        )

        if config.continue_train != "y":
            # These keyword arguments share their names with config attributes.
            passthrough = (
                "unfreeze_from",
                "backbone",
                "dense1",
                "dense2",
                "dropout1",
                "dropout2",
                "dropout3",
                "brightness",
                "contrast",
                "height",
                "width",
                "rotation",
            )
            network = trainer.compile_model(
                image_shape=config.image_shape,
                param_shape=config.feature_shape,
                learning_rate=config.learning_rate or 0.001,
                **{key: getattr(config, key) for key in passthrough},
            )
        else:
            # Resume from a previously saved model instead of compiling.
            network = tf.keras.models.load_model(config.model_path)

        hooks = _setup_callbacks(config, trainer)
        history = trainer.train_model(
            class_weights=class_weights,
            model=network,
            epochs=config.max_epochs or 20,
            callbacks=hooks,
            history_file=config.save_history_file,
            save_model_file=config.save_model_file,
        )

        log_cnn_training_run(
            config,
            network,
            history,
            dataset_path=config.csv_path,
            artifacts={
                "model": config.save_model_file,
                "history": config.save_history_file,
            },
        )

        # One plot per (title, training key, validation key) triple.
        plot_specs = (
            ("Loss vs. Val Loss", "loss", "val_loss"),
            ("Accuracy vs. Val Accuracy", "accuracy", "val_accuracy"),
        )
        for title, train_key, val_key in plot_specs:
            trainer.plot_metric(
                title,
                history.history[train_key],
                history.history[val_key],
                train_key,
                val_key,
                "epochs",
                train_key,
                model_path=config.save_model_file,
            )
[docs]
class TrainMLP(BaseCommand):
    """Train an MLP model for tension prediction."""

    def _execute_command(self, config) -> None:
        """Build datasets, compile and train the MLP, log the run, plot curves.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.

        Raises:
            ImportError: If the module-level ``tf`` is ``None``.
        """
        if tf is None:
            raise ImportError("TensorFlow is required for MLP training")

        data = MLPDataCollector(
            config.csv_path,
            config.img_folder,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        images, features, labels = data.extract_data()
        train_ds, test_ds = data.create_datasets(
            images,
            features,
            labels,
            config.test_size,
            config.buffer_size,
            config.batch_size,
            feature_scaler_path=config.feature_scaler_path,
            tension_scaler_path=config.label_scaler_path,
        )

        trainable_model = BuildMLPModel(
            config.model_path,
            train_ds,
            test_ds,
            num_classes=config.num_classes,
        )

        # Setup callbacks
        callbacks = _setup_callbacks(config, trainable_model)
        max_epochs = config.max_epochs or 20

        compiled_model = trainable_model.compile_model(
            param_shape=config.feature_shape,
            # Fall back to 0.001 when no learning rate is configured, the
            # same default KFoldMLP applies, so None never reaches the model.
            learning_rate=config.learning_rate or 0.001,
            dense1=config.dense1,
            dense2=config.dense2,
            dropout1=config.dropout1,
            dropout2=config.dropout2,
            dropout3=config.dropout3,
        )

        history = trainable_model.train_model(
            class_weights=None,
            model=compiled_model,
            epochs=max_epochs,
            callbacks=callbacks,
            history_file=config.save_history_file,
            save_model_file=config.save_model_file,
        )

        log_mlp_training_run(
            config,
            compiled_model,
            history,
            dataset_path=config.csv_path,
            artifacts={
                "model": config.save_model_file,
                "history": config.save_history_file,
            },
        )

        # Plot training metrics
        trainable_model.plot_metric(
            "Loss vs. Val Loss",
            history.history["loss"],
            history.history["val_loss"],
            "loss",
            "val_loss",
            "epochs",
            "loss",
            model_path=config.save_model_file,
        )
        trainable_model.plot_metric(
            "MAE vs. Val MAE",
            history.history["mae"],
            history.history["val_mae"],
            "mae",
            "val_mae",
            "epochs",
            "mae",
            model_path=config.save_model_file,
        )
[docs]
class TrainXGBoost(BaseCommand):
    """Train an XGBoost model for predicting delta in tension"""

    def _execute_command(self, config) -> None:
        """Build image-only datasets, fit and save the XGBoost model, log it.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.
        """
        data = MLPDataCollector(
            csv_path=config.csv_path,
            img_folder=config.img_folder,
            backbone=None,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        images, features, labels = data.extract_data()
        train_ds, test_ds = data.create_datasets(
            images,
            features,
            labels,
            test_size=config.test_size,
            batch_size=config.batch_size,
            buffer_size=config.buffer_size,
            feature_scaler_path=None,
            tension_scaler_path=config.label_scaler_path,
        )
        # Reduce both splits through the collector's image-only conversion.
        train_ds = data.image_only_dataset(train_ds)
        test_ds = data.image_only_dataset(test_ds)

        xgb_model = XGBoostModel(
            csv_path=config.csv_path,
            cnn_model_path=config.model_path,
            train_ds=train_ds,
            test_ds=test_ds,
        )
        evals_result = xgb_model.train(
            n_estimators=config.n_estimators,
            learning_rate=config.learning_rate,
            max_depth=config.max_depth,
            random_state=config.random_state,
            gamma=config.gamma,
            subsample=config.subsample,
            reg_lambda=config.reg_lambda,
        )
        xgb_model.save(config.xgb_path)

        # NOTE(review): relies on a private XGBoostModel helper to obtain the
        # training matrix passed to the run logger.
        X_train, y_train = xgb_model._extract_features_and_labels(train_ds)
        log_xgb_training_run(
            config=config,
            model=xgb_model.get_model(),
            evals_result=evals_result,
            X_train=X_train,
            y_train=y_train,
            dataset_path=config.csv_path,
            artifacts={"model": config.xgb_path},
        )

        xgb_model.plot_metrics(
            title="RMSE vs. Val RMSE",
            metric1=evals_result["validation_0"]["rmse"],
            metric2=evals_result["validation_1"]["rmse"],
            # Label previously read "RSME"; corrected to match the metric.
            metric1_label="RMSE",
            metric2_label="Val RMSE",
            x_label="Training Round",
            y_label="RMSE",
        )
[docs]
class TrainImageOnly(BaseCommand):
    """Command that trains the CNN using images as the only model input."""

    def _execute_command(self, config) -> None:
        """Build image-only datasets, train (or resume), log the run and plot.

        Any exception raised during the workflow is printed together with
        its traceback and then re-raised.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.

        Raises:
            ImportError: If the module-level ``tf`` is ``None``.
        """
        if tf is None:
            raise ImportError("TensorFlow is required for image-only training")

        try:
            collector = DataCollector(
                config.csv_path,
                config.img_folder,
                classification_type=config.classification_type,
                backbone=config.backbone,
                set_mask=config.set_mask,
                encoder_path=config.encoder_path,
                angle_threshold=config.angle_threshold,
                diameter_threshold=config.diameter_threshold,
            )
            images, features, labels = collector.extract_data()
            train_ds, test_ds, class_weights = collector.create_datasets(
                images,
                features,
                labels,
                config.test_size,
                config.buffer_size,
                config.batch_size,
            )

            # Convert both splits to image-only datasets.
            train_ds = collector.image_only_dataset(train_ds)
            test_ds = collector.image_only_dataset(test_ds)

            trainer = CustomModel(
                train_ds,
                test_ds,
                classification_type=config.classification_type,
                num_classes=config.num_classes,
            )

            if config.continue_train != "y":
                # Keyword arguments named after their config attributes.
                passthrough = (
                    "backbone",
                    "dropout1",
                    "dense1",
                    "dropout2",
                    "l2_factor",
                    "num_classes",
                    "unfreeze_from",
                )
                network = trainer.compile_image_only_model(
                    config.image_shape,
                    config.learning_rate or 0.001,
                    **{key: getattr(config, key) for key in passthrough},
                )
            else:
                # Resume from a previously saved model instead of compiling.
                network = tf.keras.models.load_model(config.model_path)

            hooks = _setup_callbacks(config, trainer)
            history = trainer.train_model(
                class_weights,
                network,
                epochs=config.max_epochs or 20,
                callbacks=hooks,
                history_file=config.save_history_file,
                save_model_file=config.save_model_file,
            )

            log_image_training_run(
                config,
                network,
                history,
                dataset_path=config.csv_path,
                artifacts={
                    "model": config.save_model_file,
                    "history": config.save_history_file,
                },
            )

            # One plot per (title, training key, validation key) triple.
            plot_specs = (
                ("Loss vs. Val Loss", "loss", "val_loss"),
                ("Accuracy vs. Val Accuracy", "accuracy", "val_accuracy"),
            )
            for title, train_key, val_key in plot_specs:
                trainer.plot_metric(
                    title,
                    history.history[train_key],
                    history.history[val_key],
                    train_key,
                    val_key,
                    "epochs",
                    train_key,
                    model_path=config.save_model_file,
                )
        except Exception as e:
            # Report the failure, then let the caller see the same exception.
            print(f"Error during image-only training: {e}")
            traceback.print_exc()
            raise
[docs]
class KFoldCNN(BaseCommand):
    """Train CNN model using k-fold cross validation."""

    def _execute_command(self, config) -> None:
        """Build k-fold datasets, train across the folds and average results.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.
        """
        data = DataCollector(
            config.csv_path,
            config.img_folder,
            backbone=config.backbone,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        images, features, labels = data.extract_data()
        datasets = data.create_kfold_datasets(
            images,
            features,
            labels,
            config.buffer_size,
            config.batch_size,
            train_p=config.train_p,
            test_p=config.test_p,
        )

        # Only the per-fold histories are used; the first return value of
        # train_kfold is discarded.
        _, kfold_histories = CustomModel.train_kfold(
            datasets,
            config.image_shape,
            config.feature_shape,
            config.learning_rate or 0.001,
            num_classes=config.num_classes,
            # Apply the same 20-epoch fallback as the other training
            # commands so an unset max_epochs is never forwarded as None.
            epochs=config.max_epochs or 20,
            history_file=config.save_history_file,
            save_model_file=config.save_model_file,
            dense1=config.dense1,
            dense2=config.dense2,
            dropout1=config.dropout1,
            dropout2=config.dropout2,
            dropout3=config.dropout3,
            brightness=config.brightness,
            contrast=config.contrast,
            height=config.height,
            width=config.width,
            rotation=config.rotation,
        )
        CustomModel.get_averages_from_kfold(kfold_histories)
[docs]
class KFoldMLP(BaseCommand):
    """Command that trains the MLP model with k-fold cross validation."""

    def _execute_command(self, config) -> None:
        """Build k-fold datasets, train the MLP per fold and average results.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.
        """
        collector = MLPDataCollector(
            config.csv_path,
            config.img_folder,
            backbone=None,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        images, features, labels = collector.extract_data()

        # The collector returns the folds together with a scaler that is
        # later handed to the averaging step.
        folds, scaler = collector.create_kfold_datasets(
            images,
            features,
            labels,
            config.buffer_size,
            config.batch_size,
        )

        # Keyword arguments named after their config attributes.
        passthrough = (
            "dense1",
            "dense2",
            "dropout1",
            "dropout2",
            "dropout3",
            "num_classes",
        )
        fold_histories = BuildMLPModel.train_kfold_mlp(
            folds,
            config.model_path,
            config.feature_shape,
            config.learning_rate or 0.001,
            **{key: getattr(config, key) for key in passthrough},
            history_file=config.save_history_file,
            save_model_file=config.save_model_file,
        )
        BuildMLPModel.get_averages_from_kfold(fold_histories, scaler)
[docs]
class TrainCustomModel(BaseCommand):
    """Command that trains a custom CNN model without a pre-trained base."""

    def _execute_command(self, config) -> None:
        """Build the custom dataset, compile and train the model, plot curves.

        Args:
            config: Settings object; paths, thresholds and hyperparameters
                are read from it as attributes.
        """
        collector = DataCollector(
            config.csv_path,
            config.img_folder,
            angle_threshold=config.angle_threshold,
            diameter_threshold=config.diameter_threshold,
        )
        train_ds, test_ds = collector.create_custom_dataset(
            config.image_shape,
            config.test_size,
            config.buffer_size,
            config.batch_size,
        )

        trainer = CustomModel(train_ds, test_ds)
        # NOTE(review): the learning rate is forwarded without the
        # `or 0.001` fallback that TrainCNN applies — confirm that
        # compile_custom_model copes with an unset value.
        network = trainer.compile_custom_model(
            config.image_shape, config.learning_rate
        )

        hooks = _setup_callbacks(config, trainer)
        history = trainer.train_model(
            model=network,
            epochs=config.max_epochs or 20,
            class_weights=None,
            callbacks=hooks,
            history_file=config.save_history_file,
            save_model_file=config.save_model_file,
        )

        # One plot per (title, training key, validation key) triple.
        plot_specs = (
            ("Loss vs. Val Loss", "loss", "val_loss"),
            ("Accuracy vs. Val Accuracy", "accuracy", "val_accuracy"),
        )
        for title, train_key, val_key in plot_specs:
            trainer.plot_metric(
                title,
                history.history[train_key],
                history.history[val_key],
                train_key,
                val_key,
                "epochs",
                train_key,
                model_path=config.save_model_file,
            )
[docs]
class TrainRL(BaseCommand):
    """Command that trains the reinforcement learning agent."""

    def _execute_command(self, config) -> None:
        """Construct the agent trainer, run training and save the agent.

        Args:
            config: Settings object; paths and hyperparameters are read
                from it as attributes.
        """
        # Every TrainAgent keyword shares its name with a config attribute.
        agent_fields = (
            "csv_path",
            "cnn_path",
            "img_folder",
            "threshold",
            "feature_shape",
            "max_steps",
            "low_range",
            "high_range",
            "max_delta",
            "max_tension_change",
        )
        agent_trainer = TrainAgent(
            **{key: getattr(config, key) for key in agent_fields}
        )

        # Training keywords that are likewise named after config attributes.
        training_fields = (
            "buffer_size",
            "learning_rate",
            "batch_size",
            "tau",
            "timesteps",
        )
        # NOTE(review): the device is hard-coded to "cuda" — confirm the
        # trainer behaves sensibly on hosts without a GPU.
        agent_trainer.train(
            env=agent_trainer.env,
            device="cuda",
            **{key: getattr(config, key) for key in training_fields},
        )
        agent_trainer.save_agent(save_path=config.agent_path)