focal.rl_pipeline module
Custom Gym environment that simulates cleaving fibers using a surrogate CNN model.
The agent adjusts the tension over multiple steps to achieve optimal cleave quality, which is evaluated by the CNN surrogate model. Observations include the fiber context and the current tension; rewards are based on the CNN predictions plus a Gaussian reward term when the predicted tension is close to the optimal value.
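The reward formula itself is not spelled out here. As a minimal sketch, assuming the reward is a weighted sum of the CNN's predicted quality and a Gaussian proximity bonus centered on an optimal tension, the `quality_weight`, `proximity_weight`, and `scale` constructor arguments of `CleaveEnv` could enter as follows. The function `cleave_reward` and its `optimal_tension` argument are hypothetical, not part of `focal.rl_pipeline`.

```python
import math


def cleave_reward(
    cnn_pred: float,
    tension: float,
    optimal_tension: float,
    quality_weight: float = 100.0,
    proximity_weight: float = 50.0,
    scale: float = 25.0,
) -> float:
    """Hypothetical reward: CNN quality term plus a Gaussian proximity bonus."""
    # Quality term: scale the surrogate CNN's predicted cleave quality.
    quality_term = quality_weight * cnn_pred
    # Proximity term: Gaussian bump that equals 1 at the optimal tension
    # and decays with standard deviation `scale` as the tension moves away.
    distance = tension - optimal_tension
    proximity_term = proximity_weight * math.exp(-(distance ** 2) / (2.0 * scale ** 2))
    return quality_term + proximity_term
```

With these defaults, an exact tension match adds the full 50-point bonus, a tension one `scale` (25 tension units) away still earns about 61% of it, and far from the optimum the Gaussian term vanishes, leaving only the quality term.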
- class focal.rl_pipeline.CleaveEnv(csv_path: str, cnn_path: str, img_folder: str, feature_shape: List[int], threshold: float, max_steps: int, low_range: float, high_range: float, max_delta: float, max_tension_change: float, quality_weight=100.0, proximity_weight=50.0, scale=25.0)
Bases:
Env
Creates the simulated cleave environment.
- load_process_images(filename: str) → Tensor
Load and preprocess an image from a file path.
- Parameters:
filename – Image filename or path
- Returns:
Preprocessed image tensor
- Return type:
tf.Tensor
- metadata: dict[str, Any] = {'render_modes': ['human']}
- render(action: ndarray, cnn_pred: float, reward: float) → None
Render the environment’s current state in human-readable format.
- Parameters:
action (np.ndarray) – The action taken (as a 1D array).
cnn_pred (float) – The CNN’s predicted cleave quality.
reward (float) – The reward received after the action.
- reset(seed: int | None = None, options: dict | None = None) → Tuple[ndarray, Dict[str, Any]]
Reset the environment to an initial state.
- Parameters:
seed (Optional[int]) – Random seed for reproducibility.
options (Optional[dict]) – Additional options for reset (unused).
- Returns:
Initial observation and empty info dict.
- Return type:
Tuple[np.ndarray, dict]
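A minimal, hypothetical sketch of the reset contract: re-seed the RNG when a `seed` is given, pick one fiber's context features, draw a starting tension, and return the observation with an empty info dict. It assumes that `low_range` and `high_range` bound the starting tension and that the observation is the context features with the tension appended. It returns a plain list rather than an `np.ndarray` so it runs on the standard library alone; `ResetSketch` is an illustrative name, not the module's class.

```python
import random
from typing import Any, Dict, List, Optional, Tuple


class ResetSketch:
    """Hypothetical stand-in showing a Gymnasium-style reset()."""

    def __init__(self, fiber_contexts: List[List[float]], low_range: float, high_range: float):
        self.fiber_contexts = fiber_contexts  # one feature vector per fiber (assumed)
        self.low_range = low_range            # assumed lower bound on starting tension
        self.high_range = high_range          # assumed upper bound on starting tension
        self.rng = random.Random()
        self.tension = 0.0
        self.current_step = 0

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[List[float], Dict[str, Any]]:
        if seed is not None:
            # Re-seeding makes the same seed reproduce the same starting state.
            self.rng.seed(seed)
        self.current_step = 0
        context = self.rng.choice(self.fiber_contexts)
        self.tension = self.rng.uniform(self.low_range, self.high_range)
        # `options` is accepted for API compatibility but ignored, as documented.
        observation = list(context) + [self.tension]
        return observation, {}
```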
- step(action: Any) → Tuple[ndarray, float, bool, bool, Dict[str, Any]]
Take a step in the environment using the given action.
- Parameters:
action (gym.ActType) – A 1D array-like action representing tension adjustment.
- Returns:
observation (np.ndarray): The next observation.
reward (float): The reward received after taking the action.
terminated (bool): True if the episode ends successfully.
truncated (bool): True if the episode is truncated (max steps reached).
info (dict): Additional info (empty by default).
- Return type:
Tuple[np.ndarray, float, bool, bool, dict]
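The step logic can be sketched under loudly stated assumptions: the 1D action is clipped to [-1, 1] and scaled by `max_tension_change`, the tension is clamped to [`low_range`, `high_range`], the episode terminates once the predicted quality reaches `threshold`, and it is truncated after `max_steps` steps. The exact roles of these parameters (and of `max_delta`, which this sketch does not use) are not documented in this section, so treat the mapping as hypothetical; `StepSketch` and `predict_quality` are illustrative stand-ins for the environment and its CNN.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


class StepSketch:
    """Hypothetical tension-adjustment dynamics following the Gymnasium step contract."""

    def __init__(self, predict_quality: Callable[[float], float], tension: float,
                 low_range: float, high_range: float, max_tension_change: float,
                 threshold: float, max_steps: int):
        self.predict_quality = predict_quality  # stand-in for the surrogate CNN
        self.tension = tension
        self.low_range = low_range
        self.high_range = high_range
        self.max_tension_change = max_tension_change
        self.threshold = threshold
        self.max_steps = max_steps
        self.current_step = 0

    def step(self, action: Sequence[float]) -> Tuple[List[float], float, bool, bool, Dict[str, Any]]:
        # Clip the 1D action to [-1, 1] and scale it to a tension adjustment (assumed).
        fraction = max(-1.0, min(1.0, float(action[0])))
        self.tension += fraction * self.max_tension_change
        # Keep the tension inside the allowed range (assumed bounds).
        self.tension = min(self.high_range, max(self.low_range, self.tension))
        self.current_step += 1

        cnn_pred = self.predict_quality(self.tension)
        reward = cnn_pred  # placeholder; the real environment shapes the reward further
        terminated = cnn_pred >= self.threshold          # good-enough cleave reached
        truncated = self.current_step >= self.max_steps  # step budget exhausted
        return [self.tension], reward, terminated, truncated, {}
```

Returning `terminated` and `truncated` separately follows the Gymnasium convention, so training code can tell a successful cleave apart from an episode that merely ran out of steps.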
- class focal.rl_pipeline.TestAgent(csv_path: str, cnn_path: str, img_folder: str, agent_path: str, feature_shape: List[int], threshold: float, max_steps: int, low_range: float, high_range: float, max_delta: float, max_tension_change: float)
Bases:
object
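`TestAgent` carries no docstring here. Presumably it loads the trained agent from `agent_path` and evaluates it in a `CleaveEnv`; a minimal evaluation loop over the `reset`/`step` contract documented above might look like this. `run_episode` and the `policy` callable are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Tuple


def run_episode(env: Any, policy: Callable[[Any], Any], seed: int = 0) -> Tuple[float, int]:
    """Roll out one episode in a Gymnasium-style env; return (total_reward, steps)."""
    observation, _info = env.reset(seed=seed)
    total_reward, steps = 0.0, 0
    done = False
    while not done:
        action = policy(observation)
        observation, reward, terminated, truncated, _info = env.step(action)
        total_reward += float(reward)
        steps += 1
        # The episode ends on success (terminated) or when the step cap is hit (truncated).
        done = terminated or truncated
    return total_reward, steps
```

If the saved agent is a Stable-Baselines3 model (an assumption), `policy` could be `lambda obs: model.predict(obs, deterministic=True)[0]`.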
- class focal.rl_pipeline.TrainAgent(csv_path: str, cnn_path: str, img_folder: str, threshold: float, feature_shape: List[int], max_steps: int, low_range: float, high_range: float, max_delta: float, max_tension_change: float)
Bases:
object
Class for training the RL agent.
- train(env: Env, device: str, buffer_size: int, learning_rate: float, batch_size: int, tau: float, timesteps: int) → None
Train the agent using the Soft Actor-Critic (SAC) algorithm.
- Parameters:
env (gym.Env) – Simulated training environment.
device (str) – Compute device; pass "cuda" to train on a GPU.
buffer_size (int) – Replay buffer size.
learning_rate (float) – Learning rate for the optimizer.
batch_size (int) – Minibatch size sampled from the replay buffer for each gradient update.
tau (float) – Soft update coefficient
timesteps (int) – Total number of training steps
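In Soft Actor-Critic, `tau` controls how quickly the target networks track the online networks through Polyak averaging, θ_target ← τ·θ_online + (1 − τ)·θ_target. The sketch below applies that update to a flat list of parameters; `soft_update` is an illustrative helper, not a function of this module.

```python
from typing import List


def soft_update(target: List[float], online: List[float], tau: float) -> List[float]:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    # tau = 0 leaves the target unchanged; tau = 1 copies the online weights.
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]
```

A small `tau` (for example 0.005) makes the targets change slowly, which stabilizes the bootstrapped Q-value targets. If `train` wraps Stable-Baselines3 (not confirmed by this page), its arguments would map onto `SAC("MlpPolicy", env, learning_rate=..., buffer_size=..., batch_size=..., tau=..., device=...)` followed by `model.learn(total_timesteps=timesteps)`.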