focal.xgboost_pipeline module

Main module for all logic related to XGBoost.

Includes classes for training and predicting.

class focal.xgboost_pipeline.XGBoostModel(csv_path: str, cnn_model_path: str, train_ds: Any, test_ds: Any)

Bases: object

This class provides logic for training, saving, and loading the XGBoost regressor.

get_model()

load(model_path: str)

Load model from path.

plot_metrics(title: str, metric1, metric2, metric1_label: str, metric2_label: str, x_label: str, y_label: str) → None

Basic plotting function for viewing metrics.

Parameters:
  • title – Title of the metric plot

  • metric1 – First metric to plot

  • metric2 – Second metric to plot

  • metric1_label – Label identifying metric 1

  • metric2_label – Label identifying metric 2

  • x_label – X-axis label

  • y_label – Y-axis label

save(save_path: str) → None

Save the XGBoost model.

Parameters:

save_path – Path where the model is saved.

Raises:

ValueError – If no model has been trained yet (the trained model is None).
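The save contract above (refuse to write anything when no model has been trained) can be sketched as follows. This is a hypothetical stand-in: the class name ModelStore and the use of pickle are assumptions for illustration only. The real XGBoostModel wraps an XGBoost model and may serialize it differently, for example with XGBoost's own save_model/load_model methods.

```python
import pickle
from typing import Any, Optional


class ModelStore:
    """Hypothetical stand-in illustrating the save()/load() contract."""

    def __init__(self) -> None:
        # Stays None until training assigns a fitted model
        self.model: Optional[Any] = None

    def save(self, save_path: str) -> None:
        # Mirror the documented behaviour: no trained model -> ValueError,
        # and no file is written
        if self.model is None:
            raise ValueError("No trained model to save; call train() first.")
        with open(save_path, "wb") as f:
            pickle.dump(self.model, f)

    def load(self, model_path: str) -> None:
        # Note: only unpickle files you trust; pickle can execute code
        with open(model_path, "rb") as f:
            self.model = pickle.load(f)


if __name__ == "__main__":
    import os
    import tempfile

    store = ModelStore()
    path = os.path.join(tempfile.mkdtemp(), "model.pkl")
    try:
        store.save(path)
    except ValueError as err:
        print("save refused:", err)

    store.model = {"n_estimators": 200}  # placeholder for a trained model
    store.save(path)

    restored = ModelStore()
    restored.load(path)
    print(restored.model)
```

Raising before opening the file means a failed save never leaves an empty or truncated model file on disk.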

train(error_type='reg:squarederror', n_estimators: int | None = 200, learning_rate: float | None = 0.05, max_depth: int | None = 4, random_state: int | None = 42, gamma: float | None = 0.0, subsample: float | None = 1.0, reg_lambda: float | None = 1.0)

Train the XGBoost regression model.

Parameters:
  • error_type – Training objective (loss function); the default 'reg:squarederror' is XGBoost's squared-error regression objective

  • n_estimators – Maximum number of boosting rounds (trees) to fit

  • learning_rate – Step-size shrinkage applied to each new tree's contribution

  • max_depth – Maximum depth of each tree

  • random_state – Random seed, fixed so that results are reproducible across runs

  • gamma – Minimum loss reduction required to make a further split on a leaf node.

  • subsample – Fraction of training rows sampled for each tree.

  • reg_lambda – L2 regularization term on leaf weights.

Returns:

The trained XGBoost model.
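To show how gamma, reg_lambda, and learning_rate interact, here is a small worked sketch of a single split under the 'reg:squarederror' objective. For squared error, each row has gradient (prediction − target) and hessian 1. The leaf weight is −G/(H + λ), and the split gain follows the formula in the XGBoost "Introduction to Boosted Trees" documentation. The helper names grad_hess, leaf_weight, and split_gain are hypothetical. This illustrates the math and is not XGBoost's internal code or this module's implementation.

```python
from typing import List, Tuple


def grad_hess(preds: List[float], targets: List[float]) -> Tuple[List[float], List[float]]:
    # Squared error 1/2 (y - p)^2: gradient g_i = p_i - y_i, hessian h_i = 1
    grads = [p - y for p, y in zip(preds, targets)]
    hess = [1.0] * len(preds)
    return grads, hess


def leaf_weight(G: float, H: float, reg_lambda: float) -> float:
    # Optimal leaf value w* = -G / (H + lambda); larger lambda shrinks it toward 0
    return -G / (H + reg_lambda)


def split_gain(GL: float, HL: float, GR: float, HR: float,
               reg_lambda: float, gamma: float) -> float:
    def score(G: float, H: float) -> float:
        return G * G / (H + reg_lambda)

    # gamma is subtracted once per split; a split is kept only if gain > 0
    return 0.5 * (score(GL, HL) + score(GR, HR) - score(GL + GR, HL + HR)) - gamma


if __name__ == "__main__":
    targets = [1.0, 1.0, 3.0, 3.0]
    preds = [2.0] * 4  # current ensemble prediction, e.g. the base score
    g, h = grad_hess(preds, targets)

    # Candidate split: first two rows go left, last two go right
    GL, HL = sum(g[:2]), sum(h[:2])
    GR, HR = sum(g[2:]), sum(h[2:])

    print("gain, gamma=0:", split_gain(GL, HL, GR, HR, reg_lambda=1.0, gamma=0.0))
    print("gain, gamma=2:", split_gain(GL, HL, GR, HR, reg_lambda=1.0, gamma=2.0))

    # learning_rate scales the new tree's leaf output before it is added
    eta = 0.05
    w_left = leaf_weight(GL, HL, reg_lambda=1.0)
    print("updated left prediction:", preds[0] + eta * w_left)
```

With gamma=0.0 this split has a positive gain (4/3). With gamma=2.0 the gain turns negative, so the split would be pruned. Raising reg_lambda shrinks both the gain and the leaf weights, and a small learning_rate such as the default 0.05 moves predictions only a fraction of the way per tree.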

class focal.xgboost_pipeline.XGBoostPredictor(csv_path: str, cnn_model_path: str, angle_threshold: float, diameter_threshold: float, xgb_path: str | None = None, scaler_path: str | None = None)

Bases: object

This class implements logic for predicting and testing the change in tension.

load()

Load the trained model and scaler.

predict()

Run tension predictions on filtered cleave data.