"""Custom Gym environment that simulates cleaving fibers using a surrogate CNN model.
The agent adjusts tension over multiple steps to achieve optimal cleave quality,
which is evaluated via a CNN surrogate model. Observations include fiber context and
tension; rewards are based on CNN predictions plus a Gaussian reward term that peaks
when the current tension is close to the optimal value.
"""
import os
from typing import Any, Dict, List, Optional, Tuple
import gymnasium as gym
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from gymnasium import spaces
from stable_baselines3 import SAC
from stable_baselines3.common.env_checker import check_env
# [docs]
class CleaveEnv(gym.Env):
    """Simulated fiber-cleaving environment driven by a surrogate model.

    Each episode samples one row of fiber context (diameter plus one-hot
    fiber type) and a random starting tension. The agent then nudges the
    tension step by step; a surrogate classifier scores the resulting cleave
    and the reward combines that score with a Gaussian bonus for being close
    to the ideal tension of the sampled fiber type.
    """

    # Only the human-readable (console print) render mode is supported.
    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        csv_path: str,
        cnn_path: str,
        img_folder: str,
        feature_shape: List[int],
        threshold: float,
        max_steps: int,
        low_range: float,
        high_range: float,
        max_delta: float,
        max_tension_change: float,
        quality_weight: float = 100.0,
        proximity_weight: float = 50.0,
        scale: float = 25.0,
    ) -> None:
        """Initialize the environment.

        Args:
            csv_path (str): Path to the CSV file with cleave data. It must
                contain the columns ``CNN_Predicition``, ``FiberType``,
                ``CleaveTension`` and ``Diameter``.
            cnn_path (str): Path to the surrogate model, loaded with joblib.
                The model must expose ``feature_names_in_`` and
                ``predict_proba``.
            img_folder (str): Directory containing cleave images.
            feature_shape (List): Shape of the numerical features (stored,
                not used by the environment itself).
            threshold (float): Predicted probability at or above which a
                cleave counts as good and the episode terminates.
            max_steps (int): Maximum number of steps per episode.
            low_range (float): Multiplier on the ideal tension giving the
                lowest allowed tension.
            high_range (float): Multiplier on the ideal tension giving the
                highest allowed tension.
            max_delta (float): Maximum change in tension per action (stored,
                not used by the environment itself).
            max_tension_change (float): Tension change corresponding to an
                action of magnitude 1.
            quality_weight (float): Weight applied to the surrogate's
                predicted probability in the reward.
            proximity_weight (float): Weight of the Gaussian proximity reward
                and of the penalty applied when an episode is truncated.
            scale (float): Standard deviation of the Gaussian proximity
                reward, in tension units.
        """
        super().__init__()
        # NOTE: joblib.load unpickles arbitrary objects; only load model
        # files that come from a trusted source.
        self.cnn_model = joblib.load(cnn_path)
        self.img_folder = img_folder
        self.df = pd.read_csv(csv_path)
        self.feature_shape = feature_shape
        self.threshold = threshold
        self.low_range = low_range
        self.high_range = high_range
        self.max_delta = max_delta
        self.max_tension_change = max_tension_change
        self.QUALITY_WEIGHT = quality_weight
        self.PROXIMITY_WEIGHT = proximity_weight
        self.SCALE = scale

        # Ideal tension per fiber type = mean tension of the rows that were
        # labelled as good cleaves.
        filtered_df = self.df[self.df["CNN_Predicition"] == 1]
        self.ideal_tensions = dict(
            filtered_df.groupby("FiberType")["CleaveTension"]
            .mean()
            .astype(np.float32)
        )

        len_fibers = len(self.df["FiberType"].unique())
        # One-hot encode fiber names; get_dummies appends the new
        # ``FiberType_*`` columns at the end of the frame.
        self.df = pd.get_dummies(
            self.df, columns=["FiberType"], dtype=np.int32
        )

        # A single continuous action in [-1, 1], scaled by
        # ``max_tension_change`` inside ``step``.
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(1,), dtype=np.float32
        )

        fiber_types = self.df.iloc[:, -len_fibers:]
        self.model_features = self.cnn_model.feature_names_in_
        other_inputs = self.df["Diameter"]
        # Context shown to the agent: diameter followed by one-hot fiber type.
        self.context_df = pd.concat([other_inputs, fiber_types], axis=1)

        # + 3 for current tension, tension error, and last reward
        observations_total = 3 + len(self.context_df.columns)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(observations_total,),
            dtype=np.float32,
        )

        self.max_steps = max_steps
        self.current_step = 0
        self.current_context = None
        self.current_tension = 0
        # Defined here so the attributes exist before the first reset().
        self.current_ideal_tension = 0.0
        self.last_reward = 0.0
        self.render_mode = None

    # For use if you decide to include the full CNN in the environment
    def load_process_images(self, filename: str) -> "tf.Tensor":
        """Load and preprocess an image from ``img_folder``.

        The file is decoded as a single-channel PNG, resized to 224x224 and
        expanded to three channels.

        Args:
            filename: Image filename, joined onto ``self.img_folder``.

        Returns:
            tf.Tensor: Preprocessed image tensor with static shape
            ``[224, 224, 3]``.

        Raises:
            ImportError: If TensorFlow is not available.
            ValueError: If the image could not be read or decoded.
        """
        if tf is None:
            raise ImportError("TensorFlow is required for image processing")

        def load_image(file):
            """Read, decode and resize one image.

            Args:
                file: Path of the image relative to ``img_folder``.

            Returns:
                The resized RGB image tensor, or ``None`` when reading or
                decoding failed (the failure is printed).
            """
            full_path = os.path.join(self.img_folder, file)
            try:
                img_raw = tf.io.read_file(full_path)
            except FileNotFoundError:
                print(f"Image file not found: {full_path}")
                return None
            except Exception as e:
                print(f"Error loading image {full_path}: {e}")
                return None
            try:
                img = tf.image.decode_png(img_raw, channels=1)
                img = tf.image.resize(img, [224, 224])
                img = tf.image.grayscale_to_rgb(img)
                return img
            except Exception as e:
                print(f"Error processing image {full_path}: {e}")
                return None

        img = load_image(filename)
        if img is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"Could not load or decode image: {filename}")
        img.set_shape([224, 224, 3])
        return img

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the environment to a new random scenario.

        Args:
            seed (Optional[int]): Random seed for reproducibility.
            options (Optional[dict]): Additional options for reset (unused).

        Returns:
            Tuple[np.ndarray, dict]: Initial observation and an info dict
            with the keys ``fiber_type`` and ``start_tension``.

        Raises:
            KeyError: If the sampled fiber type has no ideal tension, i.e.
                the CSV holds no good cleave for it.
        """
        super().reset(seed=seed)
        self.last_reward = 0.0
        self.current_context = self.context_df.sample(
            n=1, random_state=self.np_random
        )
        fiber_type = self._get_current_fiber_type()
        try:
            self.current_ideal_tension = self.ideal_tensions[fiber_type]
        except KeyError:
            raise KeyError(
                f"No ideal tension available for fiber type {fiber_type!r}"
            ) from None
        # Start somewhere inside the allowed tension band for this fiber.
        self.current_tension = self.np_random.uniform(
            low=self.current_ideal_tension * (self.low_range),
            high=self.current_ideal_tension * (self.high_range),
        )
        self.current_step = 0
        observation = self._create_observation()
        if self.render_mode == "human":
            print("\n---------------EPISODE RESET----------------------")
            print(
                f"New Scenario: Fiber = {fiber_type} Start Tension = {self.current_tension:.0f}"
            )
        # Log info
        info = {
            "fiber_type": fiber_type,
            "start_tension": self.current_tension,
        }
        return observation, info

    def step(
        self,
        action: Any,
    ) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Apply a tension adjustment and score the resulting cleave.

        Args:
            action (gym.ActType): A 1D array-like whose first element is the
                tension adjustment in [-1, 1]; it is scaled by
                ``max_tension_change``.

        Returns:
            Tuple:
                - observation (np.ndarray): The next observation.
                - reward (float): The reward received after taking the action.
                - terminated (bool): True if the surrogate prediction reached
                  the threshold.
                - truncated (bool): True if the maximum number of steps was
                  reached.
                - info (dict): ``cnn_pred``, ``current_tension``,
                  ``current_ideal_tension``, ``tension_error`` and ``action``
                  (the applied tension change before clipping).
        """
        self.current_ideal_tension = self.ideal_tensions[
            self._get_current_fiber_type()
        ]
        action_value = action[0]
        delta_tension = float(action_value * self.max_tension_change)
        self.current_tension = self.current_tension + delta_tension

        # Keep the tension inside the allowed band around the ideal tension.
        min_tension = self.current_ideal_tension * self.low_range
        max_tension = self.current_ideal_tension * self.high_range
        self.current_tension = np.clip(
            self.current_tension, min_tension, max_tension
        )

        # Increment steps in episode
        self.current_step += 1

        # Build the surrogate input: sampled context plus the new tension,
        # ordered the way the model was fitted.
        model_inputs = self.current_context.copy()
        model_inputs["CleaveTension"] = self.current_tension
        model_inputs = model_inputs[self.model_features]
        tension_error = self.current_tension - self.current_ideal_tension

        # Surrogate probability of a good cleave, in range (0, 1)
        cnn_pred = self.cnn_model.predict_proba(model_inputs)[0, 1]

        reward = 0.0
        terminated = False
        reward += self.QUALITY_WEIGHT * cnn_pred
        if cnn_pred >= self.threshold:
            reward += 50.0
            terminated = True

        # Gaussian reward for proximity to current ideal tension
        scale = self.SCALE
        proximity_reward = self.PROXIMITY_WEIGHT * np.exp(
            -(tension_error**2) / (2 * scale**2)
        )
        reward += proximity_reward

        # Decrease reward if pinned at the max or min tension
        if np.isclose(self.current_tension, min_tension) or np.isclose(
            self.current_tension, max_tension
        ):
            reward -= 75.0

        # Penalize actions that push further away from the ideal tension
        # (already above and increasing, or already below and decreasing).
        if (tension_error > 0) and (action_value > 0):
            reward -= 25.0
        elif (tension_error < 0) and (action_value < 0):
            reward -= 25.0
        else:
            reward += 5.0

        # Constant per-step cost plus a quadratic cost on action magnitude.
        reward -= 1.0
        reward -= (abs(action_value) ** 2) * 2.0

        truncated = self.current_step >= self.max_steps
        if truncated and not terminated:
            reward -= self.PROXIMITY_WEIGHT * (1.0 - cnn_pred)

        if self.render_mode == "human":
            self.render(action, cnn_pred, reward)

        observation = self._create_observation()
        # Log info
        info = {
            "cnn_pred": float(cnn_pred),
            "current_tension": round(float(self.current_tension), 3),
            "current_ideal_tension": round(
                float(self.current_ideal_tension), 3
            ),
            "tension_error": round(float(tension_error), 3),
            # Index the element explicitly: float() on a 1-element array is
            # deprecated in recent NumPy releases.
            "action": round(
                float(action_value) * self.max_tension_change, 3
            ),
        }
        self.last_reward = reward
        return observation, float(reward), terminated, truncated, info

    def _get_current_fiber_type(self) -> str:
        """Get the name of the current fiber type from the one-hot context.

        Returns:
            str: The name of the current fiber type, or 'Unknown' if no
            ``FiberType_*`` column is set.
        """
        for col_name in self.current_context.columns:
            if (
                "FiberType_" in col_name
                and self.current_context[col_name].iloc[0] == 1.0
            ):
                return col_name.replace("FiberType_", "")
        return "Unknown"

    def _create_observation(self) -> np.ndarray:
        """Create a numeric observation vector from the current state.

        Returns:
            np.ndarray: float32 vector laid out as current tension, tension
            error (ideal minus current), the context values, last reward.
        """
        tension_error = self.current_ideal_tension - self.current_tension
        return np.concatenate(
            [
                [self.current_tension],
                [tension_error],
                self.current_context.values[0],
                np.array([self.last_reward]),
            ]
        ).astype(np.float32)

    def render(
        self, action: np.ndarray, cnn_pred: float, reward: float
    ) -> None:
        """Print the environment's current state in human-readable format.

        Args:
            action (np.ndarray): The action taken (as a 1D array).
            cnn_pred (float): The surrogate's predicted cleave quality.
            reward (float): The reward received after the action.
        """
        action_str = f"{(action[0] * self.max_tension_change):+.2f}"
        # Use the same comparison as step() so a prediction exactly at the
        # threshold is labelled GOOD when the episode terminates on it.
        cnn_str = "GOOD" if cnn_pred >= self.threshold else "BAD"
        print(
            f"Step {self.current_step:2d} Tension: {self.current_tension:6.1f} (Action: {action_str:6s}) -> CNN: {cnn_str:4s}| Reward: {reward:6.1f}"
        )
# [docs]
class TrainAgent:
    """Build a ``CleaveEnv`` and train a Soft Actor-Critic agent on it."""

    def __init__(
        self,
        csv_path: str,
        cnn_path: str,
        img_folder: str,
        threshold: float,
        feature_shape: List[int],
        max_steps: int,
        low_range: float,
        high_range: float,
        max_delta: float,
        max_tension_change: float,
    ) -> None:
        """Initialize the training environment for the RL agent.

        All arguments are forwarded unchanged to ``CleaveEnv``.

        Args:
            csv_path (str): Path to the CSV file with cleave data.
            cnn_path (str): Path to the trained CNN model used as a surrogate.
            img_folder (str): Directory containing cleave images.
            threshold (float): Classification threshold for good cleave.
            feature_shape (List): Shape of the numerical features.
            max_steps (int): Maximum number of steps to use per episode.
            low_range (float): Multiplier on the ideal tension giving the
                lowest allowed tension.
            high_range (float): Multiplier on the ideal tension giving the
                highest allowed tension.
            max_delta (float): Maximum change in tension per action.
            max_tension_change (float): Absolute maximum tension change.
        """
        # Populated by train(); lets save_agent() detect an untrained agent.
        self.agent = None
        # Initialize environment
        self.env = CleaveEnv(
            csv_path=csv_path,
            cnn_path=cnn_path,
            img_folder=img_folder,
            feature_shape=feature_shape,
            threshold=threshold,
            max_steps=max_steps,
            low_range=low_range,
            high_range=high_range,
            max_delta=max_delta,
            max_tension_change=max_tension_change,
        )
        # Validate the environment against the Gym API before training.
        check_env(self.env)

    def train(
        self,
        env: Optional["gym.Env"],
        device: str,
        buffer_size: int,
        learning_rate: float,
        batch_size: int,
        tau: float,
        timesteps: int,
    ) -> None:
        """Train the agent using the Soft Actor-Critic algorithm.

        Args:
            env (gym.Env): Training environment. When ``None`` the
                environment built in ``__init__`` is used.
            device (str): Device to train on, e.g. ``"cuda"`` to use a GPU.
            buffer_size (int): Replay buffer size.
            learning_rate (float): Optimizer learning rate.
            batch_size (int): Number of transitions sampled from the replay
                buffer per gradient update.
            tau (float): Soft update coefficient.
            timesteps (int): Total number of environment steps to train for.
        """
        # Honor the caller-supplied environment instead of silently
        # ignoring it; fall back to the one created in __init__.
        training_env = self.env if env is None else env
        self.agent = SAC(
            "MlpPolicy",
            env=training_env,
            device=device,
            verbose=0,
            buffer_size=buffer_size,
            ent_coef=0.3,  # Fixed entropy regularization coefficient
            learning_rate=learning_rate,
            batch_size=batch_size,
            tau=tau,
        )
        self.agent.learn(total_timesteps=timesteps, progress_bar=True)

    def save_agent(self, save_path: str) -> None:
        """Save the trained agent to ``save_path``.

        Args:
            save_path (str): Destination path passed to the agent's ``save``.

        Raises:
            RuntimeError: If ``train`` has not been called yet.
        """
        agent = getattr(self, "agent", None)
        if agent is None:
            raise RuntimeError(
                "No trained agent to save; call train() first."
            )
        agent.save(save_path)
# [docs]
class TestAgent:
    """Load a trained SAC agent and evaluate it on a ``CleaveEnv``."""

    # The name starts with "Test"; stop pytest from trying to collect it.
    __test__ = False

    def __init__(
        self,
        csv_path: str,
        cnn_path: str,
        img_folder: str,
        agent_path: str,
        feature_shape: List[int],
        threshold: float,
        max_steps: int,
        low_range: float,
        high_range: float,
        max_delta: float,
        max_tension_change: float,
    ) -> None:
        """Initialize the environment and load a trained agent.

        Args:
            csv_path (str): Path to the CSV file with cleave data.
            cnn_path (str): Path to the trained CNN model used as a surrogate.
            img_folder (str): Directory containing cleave images.
            agent_path (str): Path of the saved SAC agent to load.
            feature_shape (List): Shape of the numerical features.
            threshold (float): Classification threshold for good cleave.
            max_steps (int): Maximum number of steps to use per episode.
            low_range (float): Multiplier on the ideal tension giving the
                lowest allowed tension.
            high_range (float): Multiplier on the ideal tension giving the
                highest allowed tension.
            max_delta (float): Maximum change in tension per action.
            max_tension_change (float): Absolute maximum tension change.
        """
        self.env = CleaveEnv(
            csv_path=csv_path,
            cnn_path=cnn_path,
            img_folder=img_folder,
            threshold=threshold,
            feature_shape=feature_shape,
            max_steps=max_steps,
            low_range=low_range,
            high_range=high_range,
            max_delta=max_delta,
            max_tension_change=max_tension_change,
        )
        # Print every reset and step while evaluating.
        self.env.render_mode = "human"
        self.agent = SAC.load(agent_path)

    def test_agent(self, episodes: int) -> List[Dict[str, Any]]:
        """Test the trained RL agent on random episodes.

        Args:
            episodes (int): Total number of episodes to run.

        Returns:
            List[Dict[str, Any]]: One dict per episode with the keys
            ``start tension``, ``fiber type``, ``episode info`` (the per-step
            info dicts), ``rewards`` (per-step rewards rounded to 3 places)
            and ``episode reward`` (their unrounded sum, rounded to 3 places).
        """
        all_episode_info = []
        for episode in range(episodes):
            obs, info_reset = self.env.reset()
            done = False
            episode_reward = 0.0
            episode_info = []
            rewards = []
            while not done:
                # Deterministic actions: evaluate the policy, not exploration.
                action, _ = self.agent.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = self.env.step(
                    action
                )
                episode_info.append(info)
                rewards.append(round(reward, 3))
                episode_reward += reward
                done = terminated or truncated
            print(
                f"Episode {episode + 1} finished with a total reward of: {episode_reward:.2f}"
            )
            # Log metrics
            metrics = {
                "start tension": info_reset["start_tension"],
                "fiber type": info_reset["fiber_type"],
                "episode info": episode_info,
                "rewards": rewards,
                "episode reward": round(episode_reward, 3),
            }
            all_episode_info.append(metrics)
        # Release the environment once all episodes have run.
        self.env.close()
        return all_episode_info