Environment Designers
The Gym-TORAX package provides a powerful framework for plasma physics specialists to explore and create new environments. It is designed as a flexible toolkit that allows users with a background in plasma physics to develop and test their own scenarios, configurations, and experiments. This includes defining new actions, observations, the TORAX configuration (for the plasma properties) and rewards.
Creation of New Environments
Base Environment
- class gymtorax.envs.base_env.BaseEnv(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False)
Abstract base class for TORAX plasma simulation environments.
This class integrates TORAX physics simulations with the Gymnasium reinforcement learning framework, providing a standardized interface for plasma control tasks. It handles the complexities of time discretization, simulation management, and action/observation space construction.
The environment operates by:
Setting up logging configuration for debugging and monitoring
Initializing TORAX configuration and simulation state
Managing discrete time steps with configurable time intervals
Applying actions by updating TORAX configuration parameters
Executing simulation steps and extracting observations
Computing rewards and determining episode termination
- Variables:
observation_handler (Observation) – Handles observation space and data extraction
action_handler (ActionHandler) – Manages action space and parameter updates
config (ConfigLoader) – TORAX configuration manager
torax_app (ToraxApp) – TORAX simulation wrapper
state (dict) – Current complete plasma state
observation (dict) – Current filtered observation
T (float) – Total simulation time [s]
delta_t_a (float) – Time interval between actions [s]
current_time (float) – Current simulation time [s]
timestep (int) – Current timestep counter
terminated (bool) – Episode termination flag
truncated (bool) – Episode truncation flag
Note
Subclasses must implement these abstract methods:
_define_observation_space: Define observation space variables
_define_action_space: Define available control actions
_get_torax_config: Define TORAX configuration parameters
_compute_reward: Define reward signal (optional override)
- __init__(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False) -> None
Initialize the TORAX gymnasium environment.
This method sets up the complete simulation environment including TORAX configuration, action/observation spaces, time discretization, and rendering components.
- Parameters:
render_mode (str or None) – Rendering mode for visualization. Options: "human", "rgb_array", or None. Defaults to None.
log_level (str) – Logging level for environment operations. Options: "debug", "info", "warning", "error", "critical". Defaults to "warning".
log_file (str or None) – Path to log file for writing log messages. If None, logs to console. Defaults to None.
plot_config (str or FigureProperties) – Name of the plot configuration to use (e.g., "default"). Can also be a TORAX FigureProperties instance for custom plot configuration.
store_history (bool) – Whether to store simulation history for later saving. Set to True if you plan to use the save_file method. Defaults to False.
- Raises:
ValueError – If required parameters are missing for the chosen discretization method.
TypeError – If discretization_torax is not "auto" or "fixed".
KeyError – If required keys are missing from the TORAX configuration.
- Return type:
None
Note
Subclasses should use **kwargs to pass parameters to avoid explicit parameter listing and maintain flexibility as the base class evolves. Environment-specific defaults can be set using kwargs.setdefault() before calling super().__init__().
The environment must implement the abstract methods _define_observation_space, _define_action_space, _get_torax_config, and _compute_reward.
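The kwargs.setdefault() pattern described in the note can be sketched as follows. To keep the snippet runnable without TORAX installed, StubBaseEnv is a hypothetical stand-in that only mirrors the constructor signature documented above; a real environment would inherit from gymtorax.envs.base_env.BaseEnv instead, and HistoryEnv is an illustrative name.

```python
class StubBaseEnv:
    """Hypothetical stand-in mirroring the documented BaseEnv constructor."""

    def __init__(self, render_mode=None, log_level="warning", log_file=None,
                 plot_config="default", store_history=False):
        self.render_mode = render_mode
        self.log_level = log_level
        self.log_file = log_file
        self.plot_config = plot_config
        self.store_history = store_history


class HistoryEnv(StubBaseEnv):
    """Environment whose own defaults differ from the base class defaults."""

    def __init__(self, **kwargs):
        # Set environment-specific defaults without overriding caller choices.
        kwargs.setdefault("store_history", True)
        kwargs.setdefault("log_level", "info")
        super().__init__(**kwargs)


env = HistoryEnv(log_level="debug")
print(env.store_history, env.log_level)
```

Because setdefault() only fills in missing keys, a value passed explicitly by the caller (here log_level="debug") always wins over the subclass default.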
- reset(*, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[dict[str, Any], dict[str, Any]]
Reset the environment to its initial state for a new episode.
This method initializes a new simulation episode by:
Resetting internal counters and flags
Starting the TORAX simulation from initial conditions
Extracting the initial observation state
Optionally rendering the initial state
- Parameters:
seed (int or None) – Random seed for reproducible episode initialization. Used to seed the environment's random number generator for deterministic behavior across resets. If None, no seeding is performed. Defaults to None.
options (dict[str, Any] or None) – Additional options for environment reset. Currently unused but maintained for Gymnasium compatibility. Defaults to None.
- Returns:
observation (dict): Initial observation of plasma state
info (dict): Additional information (empty dict)
- Return type:
tuple[dict[str, Any], dict[str, Any]]
- step(action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]
Execute one environment step with the given action.
This method implements the core RL interaction by:
Capturing the current state before action
Applying the action to update TORAX configuration
Running the simulation for one time interval
Extracting the new observation state
Computing the reward signal
Checking for episode termination
Updating time counters
- Parameters:
action (dict[str, numpy.ndarray]) – Action dictionary containing parameter values for all configured actions.
- Returns:
observation (dict): New plasma state observation
reward (float): Reward signal for this step
terminated (bool): True if episode ended due to terminal condition
truncated (bool): True if episode ended due to time/step limits
info (dict): Additional step information
- Return type:
tuple[dict[str, Any], float, bool, bool, dict[str, Any]]
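The reset()/step() contract above supports the usual Gymnasium interaction loop. The sketch below shows that loop against CountdownEnv, a hypothetical stub that only reproduces the documented return shapes so the snippet runs without TORAX; the "Ip" key and its value in the policy are placeholders, not confirmed action names.

```python
from typing import Any


class CountdownEnv:
    """Hypothetical stub exposing the same reset()/step() contract as BaseEnv."""

    def __init__(self, n_steps: int = 3):
        self.n_steps = n_steps
        self.timestep = 0

    def reset(self, *, seed=None, options=None):
        self.timestep = 0
        return {"scalars": {"t": 0.0}}, {}

    def step(self, action: dict[str, Any]):
        self.timestep += 1
        observation = {"scalars": {"t": float(self.timestep)}}
        reward = 1.0
        terminated = False
        truncated = self.timestep >= self.n_steps  # time limit reached
        return observation, reward, terminated, truncated, {}


def run_episode(env, policy) -> float:
    """Run one episode and return the summed reward."""
    observation, info = env.reset(seed=0)
    total_reward = 0.0
    done = False
    while not done:
        action = policy(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done = terminated or truncated
    return total_reward


# A constant policy used purely for illustration.
episode_return = run_episode(CountdownEnv(n_steps=3), lambda obs: {"Ip": [1.0e6]})
print(episode_return)
```

The loop stops on either flag, since terminated and truncated are reported separately.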
- render() -> ndarray | None
Render the current environment state following Gymnasium convention.
- Returns:
RGB array of shape (height, width, 3) if render_mode is "rgb_array".
None if render_mode is "human" or the renderer is not available.
- Return type:
ndarray | None
- save_file(file_name)
Save the simulation output data to a file.
This method saves the complete simulation history to a specified file. The simulation must have been initialized with store_history=True for this method to work properly.
- Parameters:
file_name (str) – The path and filename where the output should be saved. The file format is typically NetCDF (.nc extension).
- Raises:
ctypes.ArgumentError – If the environment was created without store_history=True.
RuntimeError – If there was an error during the save operation.
Abstract Methods
- abstract BaseEnv._define_action_space() -> list[Action]
Define the available control actions for this environment.
This method must be implemented by concrete subclasses to specify which plasma parameters can be controlled by the RL agent.
- Returns:
List of Action instances representing controllable parameters with their bounds and TORAX configuration mappings.
- Return type:
list[Action]
Example
>>> def _define_action_space(self):
...     return [
...         IpAction(min=[0.5e6], max=[2.0e6]),  # Plasma current
...         EcrhAction(                          # ECRH heating
...             min=[0.0, 0.0, 0.0],             # [power, loc, width]
...             max=[10e6, 1.0, 0.5]
...         ),
...         NbiAction()                          # NBI with defaults
...     ]
- abstract BaseEnv._define_observation_space() -> Observation
Define the observation space variables for this environment.
This method must be implemented by concrete subclasses to specify which TORAX variables should be included in the observation space.
- Returns:
Configured observation handler that defines which plasma state variables are visible to the RL agent.
- Return type:
Observation
Example
>>> def _define_observation_space(self):
...     return AllObservation(
...         exclude=["n_impurity", "Z_impurity"],
...         custom_bounds={
...             "T_e": (0.0, 50.0),  # Temperature range in keV
...             "T_i": (0.0, 50.0)
...         }
...     )
- abstract BaseEnv._get_torax_config() -> dict[str, Any]
Define the TORAX simulation configuration.
Concrete subclasses must implement this abstract method to provide the parameters needed by the TORAX simulation: its core configuration, the time discretization method, and, depending on that method, either the control time step or the ratio between action and simulation time steps.
- Returns:
A dictionary containing the TORAX configuration. The dictionary must have the following keys:
"config" (dict): A dictionary of TORAX configuration parameters.
"discretization" (str): The time discretization method. Options are "auto" (uses "delta_t_a") or "fixed" (uses "ratio_a_sim").
"ratio_a_sim" (int or None): The ratio of action timesteps to simulation timesteps. Required if "discretization" is "fixed".
"delta_t_a" (float or None): The time interval between actions in seconds. Required if "discretization" is "auto".
- Return type:
dict[str, Any]
Example
>>> def _get_torax_config(self):
...     return {
...         "config": TORAX_CONFIG,
...         "discretization": "auto",
...         "delta_t_a": 0.05,  # 50 ms between actions
...         # "ratio_a_sim": 10,  # Only needed if using "fixed" discretization
...     }
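The key requirements above can be checked before an environment is built. The helper below is a minimal sketch: check_torax_config is a hypothetical function, not part of Gym-TORAX, and the exception types it raises are choices made for this sketch that may differ from those raised by BaseEnv.

```python
from typing import Any


def check_torax_config(cfg: dict[str, Any]) -> None:
    """Raise if the dictionary violates the documented key requirements."""
    for key in ("config", "discretization"):
        if key not in cfg:
            raise KeyError(f"missing required key: {key!r}")

    method = cfg["discretization"]
    if method == "auto":
        required = "delta_t_a"    # time between actions, in seconds
    elif method == "fixed":
        required = "ratio_a_sim"  # action steps per simulation step
    else:
        raise ValueError(f"unknown discretization: {method!r}")

    if cfg.get(required) is None:
        raise ValueError(f"{required!r} is required when discretization is {method!r}")


# Both documented combinations pass silently.
check_torax_config({"config": {}, "discretization": "auto", "delta_t_a": 0.05})
check_torax_config({"config": {}, "discretization": "fixed", "ratio_a_sim": 10})
```

Running such a check inside a unit test catches a missing delta_t_a or ratio_a_sim before a simulation is started.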
- abstract BaseEnv._compute_reward(state: dict[str, Any], next_state: dict[str, Any], action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) -> float
Define the reward signal for a state transition.
This method should be overridden by concrete subclasses to implement task-specific reward functions. The default implementation returns 0.0.
- Parameters:
state (dict[str, Any]) – Previous plasma state before action was applied. Contains complete state with "profiles" and "scalars" dictionaries.
next_state (dict[str, Any]) – New plasma state after action and simulation step. Same structure as the state parameter.
action (dict[str, numpy.ndarray]) – Action dictionary that was applied to cause this transition.
- Returns:
Reward value for this state transition.
- Return type:
float
Example
>>> def _compute_reward(self, state, next_state, action):
...     # Reward based on proximity to target beta_N
...     target_beta = 2.0
...     current_beta = next_state["scalars"]["beta_N"]
...     return -abs(current_beta - target_beta)
Here is a simple example of how to create a new environment by extending the base class:
from gymtorax.envs.base_env import BaseEnv
import gymtorax.action_handler as ah
import gymtorax.observation_handler as oh
import gymtorax.rewards as rw

# CONFIG must be a TORAX configuration dictionary defined elsewhere
# (see "Creation of Configurations").


class CustomEnv(BaseEnv):
    def _define_action_space(self):
        # Control only the plasma current, using its default bounds
        actions = [ah.IpAction()]
        return actions

    def _define_observation_space(self):
        # Expose every available state variable to the agent
        return oh.AllObservation()

    def _get_torax_config(self):
        return {
            "config": CONFIG,
            "discretization": "auto",
            "delta_t_a": 1.0,  # seconds between actions
        }

    def _compute_reward(self, current_state, next_state, action):
        Q = rw.get_fusion_gain(next_state)
        q_min = rw.get_q_min(next_state)
        w_Q, w_qmin = 1.0, 1.0

        # Bonus of 1 when the minimum safety factor exceeds 1, otherwise 0
        q_min_bonus = 1 if q_min > 1 else 0
        return w_Q * Q + w_qmin * q_min_bonus
Creation of Actions
- class gymtorax.action_handler.Action(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)
Bases: ABC
Abstract base class for all TORAX simulation actions.
An action represents a controllable parameter or set of parameters that can influence plasma behavior. Each action has bounds, current values, and knows how to map itself to TORAX configuration dictionaries.
This class is designed to be extended by users to create custom actions for specific control parameters. Subclasses must define the class attributes to specify the action dimensionality, bounds, and configuration mapping.
- Class Attributes:
name (str) – Unique identifier for this action type
dimension (int) – Number of parameters controlled by this action
default_min (list[float]) – Default minimum values for parameters
default_max (list[float]) – Default maximum values for parameters
config_mapping (dict[tuple[str, …], tuple[int, float]]) – Mapping from configuration paths to parameter indices and scaling factors. Keys are tuples representing the nested path in the config dictionary, values are tuples of (parameter_index, scaling_factor).
state_var (tuple[tuple[str, …], …]) – Tuple of tuples specifying the state variables directly modified by this action. Each inner tuple contains the path to a state variable (e.g., ('scalars', 'Ip') or ('profiles', 'p_ecrh_e')).
- Variables:
ramp_rate (numpy.ndarray) – Ramp rate limits for each parameter. numpy.inf indicates no ramp rate limit for that parameter.
dtype (numpy.dtype) – NumPy data type for action arrays (default: np.float64)
Example
Create a custom action for controlling two parameters:
>>> class TwoParamAction(Action):
...     name = "CustomTwoParam"
...     dimension = 2
...     default_min = [0.0, -5.0]
...     default_max = [10.0, 5.0]
...     default_ramp_rate = [0.5, None]  # First param limited, second unlimited
...     config_mapping = {
...         ('section', 'param1'): (0, 1),   # No scaling
...         ('section', 'param2'): (1, 0.5)  # Scale by 0.5
...     }
...     state_var = {'scalars': ['param1', 'param2']}
>>> action = TwoParamAction()
- dimension: int
- name: str
- __init__(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64) -> None
Initialize an Action instance.
- Parameters:
min (list[float] | None) – Custom minimum bounds for each parameter. If None, uses the class default_min values. Must have length equal to dimension.
max (list[float] | None) – Custom maximum bounds for each parameter. If None, uses the class default_max values. Must have length equal to dimension.
ramp_rate (list[float | None] | None) – Custom ramp rate limits for each parameter. If None, uses the class default_ramp_rate values. Must have length equal to dimension. Each element can be None (no limit) or a float (max change per step).
dtype (dtype) – NumPy data type for the action arrays (default: np.float64). Used for creating action spaces.
- Raises:
ValueError – If the name class attribute is not defined
ValueError – If the dimension class attribute is not defined or not a positive integer
ValueError – If the config_mapping class attribute is not defined
ValueError – If default_min, default_max, or default_ramp_rate do not match the dimension
ValueError – If provided min, max, or ramp_rate do not match the dimension
- Return type:
None
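The ramp_rate parameter is described above as a maximum change per step, with None meaning no limit. How Gym-TORAX enforces it internally is not shown here; the sketch below assumes a plain clip of the requested value around the previous one, followed by a clip to the bounds. limit_step is a hypothetical helper, and the numbers reuse the bounds of the TwoParamAction example.

```python
import math


def limit_step(previous, requested, lower, upper, ramp_rate):
    """Clip each requested value to its per-step ramp limit, then to its bounds."""
    limited = []
    for prev, req, lo, hi, rate in zip(previous, requested, lower, upper, ramp_rate):
        max_change = math.inf if rate is None else rate  # None means unlimited
        value = min(max(req, prev - max_change), prev + max_change)
        value = min(max(value, lo), hi)
        limited.append(value)
    return limited


result = limit_step(
    previous=[2.0, 0.0],
    requested=[5.0, 9.0],
    lower=[0.0, -5.0],
    upper=[10.0, 5.0],
    ramp_rate=[0.5, None],
)
print(result)  # [2.5, 5.0]
```

The first parameter moves by only 0.5 because of its ramp limit, while the second is unconstrained by ramp rate and is stopped only by its upper bound.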
- property min: ndarray[tuple[int, ...], dtype[floating]]
Minimum bounds for this action's parameters.
- Returns:
Array of minimum values, one for each parameter controlled by this action.
- Return type:
ndarray[tuple[int, ...], dtype[floating]]
- property max: ndarray[tuple[int, ...], dtype[floating]]
Maximum bounds for this action's parameters.
- Returns:
Array of maximum values, one for each parameter controlled by this action.
- Return type:
ndarray[tuple[int, ...], dtype[floating]]
- init_dict(config_dict: dict[str, Any]) -> None
Initialize a TORAX configuration dictionary with this action's parameters.
This method sets up the configuration dictionary with the action's current values at time=0, creating the proper time-dependent parameter structure expected by TORAX.
- Parameters:
config_dict (dict[str, Any]) – The TORAX configuration dictionary to initialize. Must have the nested structure expected by this action's config_mapping.
- Raises:
KeyError – If the configuration dictionary does not have the expected structure for this action's parameters.
RuntimeError – If any error occurs during the initialization process.
- Return type:
None
- update_to_config(config_dict: dict[str, Any], time: float) -> None
Update a TORAX configuration dictionary with new action values.
This method updates the time-dependent parameters in the configuration dictionary with the action's current values at the specified time. Scaling factors from config_mapping are applied consistently.
- Return type:
None
Note
The configuration dictionary must have been initialized with init_dict before calling this method. Values are scaled by the factors defined in config_mapping before being stored.
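The role of config_mapping can be illustrated with a small stand-alone sketch. It rests on two assumptions that this page does not confirm: that scaling means multiplying the value by the factor, and that each time-dependent parameter is stored as a {time: value} dictionary. apply_mapping is a hypothetical helper, not the library's implementation; the mapping is copied from the TwoParamAction example.

```python
def apply_mapping(config, config_mapping, values, time):
    """Write scaled action values into a nested config at the given time."""
    for path, (index, factor) in config_mapping.items():
        node = config
        for key in path[:-1]:
            node = node[key]  # raises KeyError if the path is missing
        entry = node[path[-1]]
        if not isinstance(entry, dict):
            raise TypeError(f"{path} is not a time-dependent entry")
        entry[time] = values[index] * factor  # assumed: scaling is a multiplication


mapping = {("section", "param1"): (0, 1), ("section", "param2"): (1, 0.5)}
# Assumed layout: each parameter holds a {time: value} dictionary.
config = {"section": {"param1": {0.0: 0.0}, "param2": {0.0: 0.0}}}

apply_mapping(config, mapping, values=[3.0, 4.0], time=1.0)
print(config["section"])
```

After the call, param1 holds 3.0 at time 1.0 and param2 holds 2.0, since its factor is 0.5; the entries written at time 0.0 are left untouched.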
Creation of Observations
- class gymtorax.observation_handler.Observation(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64)
Bases: ABC
Abstract base class for building observation spaces from TORAX DataTree outputs.
Converts TORAX simulation outputs into structured observation spaces for reinforcement learning environments. Handles variable selection, bounds specification, and automatic action variable filtering.
- Variables:
variables_to_include (dict) – Variables to include in observation space.
variables_to_exclude (dict) – Variables to exclude from observation space.
custom_bounds (dict) – Custom bounds for variables.
dtype (numpy.dtype) – Data type for observation arrays.
action_variables (dict) – Variables controlled by actions.
state_variables (dict) – Available variables from TORAX output.
observation_variables (dict) – Final filtered observation variables.
bounds (dict) – Final bounds after processing.
- __init__(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64) -> None
Initialize Observation handler.
Sets up configuration for the observation handler. Requires subsequent calls to set_state_variables(), set_action_variables(), and build_observation_space() before use.
- Parameters:
variables (dict[str, list[str]] | None) – Variables to include. Format: {"profiles": [names], "scalars": [names]}. If None, includes all available variables except those in exclude.
custom_bounds_filename (str | None) – Path to JSON file with custom bounds. Format: {"profiles": {var: {"min": val, "max": val}}, "scalars": {...}}.
exclude (dict[str, list[str]] | None) – Variables to exclude. Format: {"profiles": [names], "scalars": [names]}. Cannot be used with the variables parameter.
dtype (dtype) – Data type for observation arrays.
- Raises:
ValueError – If both variables and exclude are specified, or if the configuration is invalid.
- Return type:
None
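A custom bounds file in the JSON layout described for custom_bounds_filename can be produced with the standard library. The sketch below writes such a file to a temporary directory and reads it back; the variable names and limits are placeholders chosen for illustration, not recommended values.

```python
import json
import os
import tempfile

# Placeholder variables and limits, following the documented layout.
bounds = {
    "profiles": {"T_e": {"min": 0.0, "max": 50.0}},
    "scalars": {"Ip": {"min": 0.0, "max": 2.0e7}},
}

path = os.path.join(tempfile.mkdtemp(), "custom_bounds.json")
with open(path, "w", encoding="utf-8") as handle:
    json.dump(bounds, handle, indent=2)

# Read the file back and check that every entry has a consistent range.
with open(path, encoding="utf-8") as handle:
    loaded = json.load(handle)

for group in ("profiles", "scalars"):
    for name, limits in loaded[group].items():
        assert limits["min"] <= limits["max"], f"inverted bounds for {name}"

print(sorted(loaded))
```

The resulting path is what would be passed as custom_bounds_filename when constructing the observation handler.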
- set_action_variables(variables: dict[str, list[str]]) -> None
Set variables controlled by actions.
These variables are removed from the observation space to prevent redundancy between actions and observations.
- set_state_variables(state: DataTree) -> None
Set available state variables from TORAX output.
Catalogs all variables from the TORAX DataTree for inclusion in the observation space.
- Parameters:
state (DataTree) – TORAX DataTree with /profiles/ and /scalars/ datasets.
- Return type:
None
- extract_state_observation(datatree: DataTree) -> tuple[dict[str, dict[str, ndarray]], dict[str, dict[str, ndarray]]]
Extract state and observation data from TORAX output.
Returns both complete state (all variables) and filtered observation (selected variables only).
- Parameters:
datatree (DataTree) – TORAX simulation output containing profiles and scalars datasets.
- Returns:
state (dict): Complete state with all variables. Format: {"profiles": {var: array}, "scalars": {var: value}}
observation (dict): Filtered observation with selected variables. Format: {"profiles": {var: array}, "scalars": {var: value}}
- Return type:
tuple[dict[str, dict[str, numpy.ndarray]], dict[str, dict[str, numpy.ndarray]]]
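The relationship between the complete state and the filtered observation can be sketched on plain dictionaries. select_observation is a hypothetical helper that only illustrates the filtering idea on the documented {"profiles": ..., "scalars": ...} layout; it does not read a TORAX DataTree, and the values are placeholders.

```python
def select_observation(state, selected):
    """Keep only the selected variables from a {"profiles", "scalars"} state."""
    return {
        group: {name: state[group][name] for name in selected.get(group, [])}
        for group in ("profiles", "scalars")
    }


# Placeholder values; real states hold arrays produced by TORAX.
state = {
    "profiles": {"T_e": [10.0, 8.0, 5.0], "T_i": [9.0, 7.0, 4.0]},
    "scalars": {"Ip": 1.5e7, "beta_N": 1.8},
}
observation = select_observation(state, {"profiles": ["T_e"], "scalars": ["beta_N"]})
print(observation)
```

The state keeps every variable, which is why reward functions receive it rather than the reduced observation.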
- build_observation_space() -> Dict
Build Gymnasium observation space for selected variables.
Creates nested Dict space with Box spaces for each variable using configured bounds. Validates configuration and finalizes variable selection.
- Returns:
Gymnasium Dict space with structure {"profiles": {var: Box}, "scalars": {var: Box}}
- Return type:
Dict
- Raises:
ValueError – If validation fails or required setup incomplete.
Creation of Configurations
The configuration returned under the "config" key of the _get_torax_config method is a dictionary identical to the one used in TORAX. You can find more details about the configuration in the TORAX documentation. The configuration used in our example is available here.
Creation of Rewards
TORAX reward module.
This module provides functions to extract specific metrics from the state dictionary returned by the TORAX simulator. These metrics can be used to construct reward functions for reinforcement learning environments focused on tokamak control and optimization. Other reward functions can be created.
- gymtorax.rewards.get_fusion_gain(state: dict) -> float
Get the fusion gain \(Q\) from the state dictionary.
- gymtorax.rewards.get_beta_N(state: dict) -> float
Get the normalized beta \(\beta_N\) from the state dictionary.
- gymtorax.rewards.get_tau_E(state: dict) -> float
Get the energy confinement time \(\tau_E\) from the state dictionary.
- gymtorax.rewards.get_h98(state: dict) -> float
Get the H-mode confinement quality factor from the state dictionary.
- gymtorax.rewards.get_q_profile(state: dict) -> ndarray
Get the safety factor profile \(q\) from the state dictionary.
- Parameters:
state (dict) – The state dictionary containing profile values.
- Returns:
The safety factor profile \(q\).
- Return type:
ndarray
- gymtorax.rewards.get_q_min(state: dict) -> float
Get the minimum safety factor \(q_{min}\) from the state dictionary.
- gymtorax.rewards.get_q95(state: dict) -> float
Get the safety factor at 95% of the normalized poloidal flux coordinate.
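Since each function in this module maps a state dictionary to a number, a reward can be assembled as a weighted sum of such metrics. The sketch below shows one possible way to do this; weighted_reward is a hypothetical helper, and the two toy metrics read placeholder keys from a hand-written state so that the snippet runs without Gym-TORAX. With the package installed, functions such as get_fusion_gain or get_beta_N would take their place.

```python
from typing import Callable

Metric = Callable[[dict], float]


def weighted_reward(metrics: dict[str, tuple[Metric, float]]) -> Metric:
    """Build a reward function as a weighted sum of metric functions."""
    def reward(state: dict) -> float:
        return sum(weight * metric(state) for metric, weight in metrics.values())
    return reward


def toy_gain(state: dict) -> float:
    # Stand-in for a fusion gain metric; the "Q" key is a placeholder.
    return state["scalars"]["Q"]


def toy_beta_penalty(state: dict) -> float:
    # Penalize the distance of beta_N from a target of 2.0.
    return -abs(state["scalars"]["beta_N"] - 2.0)


reward_fn = weighted_reward({"gain": (toy_gain, 1.0), "beta": (toy_beta_penalty, 0.5)})
value = reward_fn({"scalars": {"Q": 4.0, "beta_N": 1.5}})
print(value)  # 3.75
```

Keeping the weights next to the metrics makes it easy to rebalance objectives without touching the body of _compute_reward.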