Environment Design
This section describes the core environment components that make up GymTorax: environments, actions, observations, and rewards. These are the building blocks for creating plasma control tasks.
Base Environment
TORAX Base Environment Module.
This module provides the abstract base class for TORAX plasma simulation environments compatible with the Gymnasium reinforcement learning framework. It integrates TORAX physics simulations with RL interfaces, handling time discretization, action/observation spaces, and the simulation lifecycle.
The BaseEnv class serves as a foundation for creating specific plasma control tasks by:
Managing TORAX configuration and simulation execution
Defining action and observation space structures
Handling time discretization and episode management
Providing hooks for custom reward functions and terminal conditions
Providing a configurable logging system for debugging and monitoring
- Classes:
BaseEnv: Abstract base class for TORAX Gymnasium environments
Example
Create a custom environment by extending BaseEnv:
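>>> from gymtorax.envs.base_env import BaseEnv
>>> from gymtorax.action_handler import IpAction, EcrhAction
>>> from gymtorax.observation_handler import AllObservation
>>> class PlasmaControlEnv(BaseEnv):
...     def __init__(self, render_mode=None, **kwargs):
...         # Set environment-specific defaults
...         kwargs.setdefault("log_level", "info")
...         super().__init__(render_mode=render_mode, **kwargs)
...
...     def _define_observation_space(self):
...         # exclude uses the {"profiles": [...], "scalars": [...]} format
...         return AllObservation(exclude={"profiles": ["n_impurity"]})
...
...     def _define_action_space(self):
...         return [IpAction(), EcrhAction()]
...
...     def _get_torax_config(self):
...         # CONFIG is a TORAX configuration dictionary defined elsewhere
...         return CONFIG
...
...     def _compute_reward(self, state, next_state, action):
...         # Custom reward logic
...         return -abs(next_state["scalars"]["beta_N"] - 2.0)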
- class gymtorax.envs.base_env.BaseEnv(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False)[source]
Abstract base class for TORAX plasma simulation environments.
This class integrates TORAX physics simulations with the Gymnasium reinforcement learning framework, providing a standardized interface for plasma control tasks. It handles the complexities of time discretization, simulation management, and action/observation space construction.
The environment operates by:
Setting up logging configuration for debugging and monitoring
Initializing TORAX configuration and simulation state
Managing discrete time steps with configurable time intervals
Applying actions by updating TORAX configuration parameters
Executing simulation steps and extracting observations
Computing rewards and determining episode termination
- Variables:
observation_handler (Observation) – Handles observation space and data extraction
action_handler (ActionHandler) – Manages action space and parameter updates
config (ConfigLoader) – TORAX configuration manager
torax_app (ToraxApp) – TORAX simulation wrapper
state (dict) – Current complete plasma state
observation (dict) – Current filtered observation
T (float) – Total simulation time [s]
delta_t_a (float) – Time interval between actions [s]
current_time (float) – Current simulation time [s]
timestep (int) – Current timestep counter
terminated (bool) – Episode termination flag
truncated (bool) – Episode truncation flag
- Parameters:
Note
Subclasses must implement these abstract methods:
_define_observation_space: Define observation space variables
_define_action_space: Define available control actions
_get_torax_config: Define TORAX configuration parameters
_compute_reward: Define reward signal (optional override)
- __init__(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False) None[source]
Initialize the TORAX gymnasium environment.
This method sets up the complete simulation environment including TORAX configuration, action/observation spaces, time discretization, and rendering components.
- Parameters:
render_mode (str or None) – Rendering mode for visualization. Options: "human", "rgb_array", or None. Defaults to None.
log_level (str) – Logging level for environment operations. Options: "debug", "info", "warning", "error", "critical". Defaults to "warning".
log_file (str or None) – Path to log file for writing log messages. If None, logs to console. Defaults to None.
plot_config (str or FigureProperties) – Name of the plot configuration to use (e.g., "default"). Can also be a torax FigureProperties instance for custom plot configuration.
store_history (bool) – Whether to store simulation history for later saving. Set to True if you plan to use the save_file method. Defaults to False.
- Raises:
ValueError – If required parameters are missing for chosen discretization method.
TypeError – If discretization_torax is not "auto" or "fixed".
KeyError – If required keys are missing from TORAX configuration.
- Return type:
None
Note
Subclasses should use **kwargs to pass parameters to avoid explicit parameter listing and maintain flexibility as the base class evolves. Environment-specific defaults can be set using kwargs.setdefault() before calling super().__init__().
The environment must implement the abstract methods _define_observation_space, _define_action_space, _get_torax_config, and _compute_reward.
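The kwargs.setdefault() pattern from the note can be tried in isolation. The sketch below uses two hypothetical stand-in classes (Base and Child, not part of GymTorax) whose constructor arguments only mirror a few of the names listed above; it shows that the subclass default applies only when the caller does not pass the argument explicitly.

```python
class Base:
    """Hypothetical stand-in for a base class with keyword defaults (not BaseEnv)."""

    def __init__(self, render_mode=None, log_level="warning", store_history=False):
        self.render_mode = render_mode
        self.log_level = log_level
        self.store_history = store_history


class Child(Base):
    """Hypothetical subclass that changes one default without listing all parameters."""

    def __init__(self, render_mode=None, **kwargs):
        # setdefault only fills the key when the caller did not provide it.
        kwargs.setdefault("log_level", "info")
        super().__init__(render_mode=render_mode, **kwargs)


default_env = Child()
custom_env = Child(log_level="debug", store_history=True)
print(default_env.log_level, custom_env.log_level)
```

Because the subclass forwards everything else through **kwargs, a parameter added to the base class later stays reachable without touching the subclass signature.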
- reset(*, seed: int | None = None, options: dict[str, Any] | None = None) tuple[dict[str, Any], dict[str, Any]][source]
Reset the environment to its initial state for a new episode.
This method initializes a new simulation episode by:
Resetting internal counters and flags
Starting the TORAX simulation from initial conditions
Extracting the initial observation state
Optionally rendering the initial state
- Parameters:
seed (int or None) – Random seed for reproducible episode initialization. Used to seed the environment’s random number generator for deterministic behavior across resets. If None, no seeding is performed. Defaults to None.
options (dict[str, Any] or None) – Additional options for environment reset. Currently unused but maintained for Gymnasium compatibility. Defaults to None.
- Returns:
observation (dict): Initial observation of plasma state
info (dict): Additional information (empty dict)
- Return type:
tuple[dict[str, Any], dict[str, Any]]
- step(action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) tuple[dict[str, Any], float, bool, bool, dict[str, Any]][source]
Execute one environment step with the given action.
This method implements the core RL interaction by:
Capturing the current state before action
Applying the action to update TORAX configuration
Running the simulation for one time interval
Extracting the new observation state
Computing the reward signal
Checking for episode termination
Updating time counters
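The step sequence above can be sketched with a toy environment. ToyEnv below is a hypothetical stand-in, not GymTorax code: it replaces the TORAX simulation with a one-line update and assumes that truncation happens once current_time reaches T. It only illustrates the order of operations and the five-element return value.

```python
class ToyEnv:
    """Hypothetical stand-in that mimics the reset/step contract described above."""

    def __init__(self, total_time=1.0, delta_t_a=0.25):
        self.T = total_time          # total simulated time [s]
        self.delta_t_a = delta_t_a   # time between two actions [s]

    def reset(self, seed=None):
        # The seed is accepted for interface compatibility; this toy is deterministic.
        self.current_time = 0.0
        self.timestep = 0
        self.state = {"scalars": {"x": 0.0}}
        return self.state, {}

    def step(self, action):
        state = self.state                                        # capture state before action
        x = state["scalars"]["x"] + action["u"] * self.delta_t_a  # apply action, advance one interval
        next_state = {"scalars": {"x": x}}                        # extract the new observation
        reward = -abs(x - 1.0)                                    # compute the reward
        terminated = abs(x) > 10.0                                # check the terminal condition
        self.current_time += self.delta_t_a                       # update time counters
        self.timestep += 1
        truncated = self.current_time >= self.T                   # assumed time-limit rule
        self.state = next_state
        return next_state, reward, terminated, truncated, {}


env = ToyEnv()
obs, info = env.reset(seed=0)
total_reward, done = 0.0, False
while not done:
    obs, reward, terminated, truncated, info = env.step({"u": 1.0})
    total_reward += reward
    done = terminated or truncated
print(env.timestep, total_reward)
```

With T = 1.0 and delta_t_a = 0.25 the loop runs four steps and ends through truncation rather than termination.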
- Parameters:
action (dict[str, numpy.ndarray]) – Action dictionary containing parameter values for all configured actions.
- Returns:
observation (dict): New plasma state observation
reward (float): Reward signal for this step
terminated (bool): True if episode ended due to terminal condition
truncated (bool): True if episode ended due to time/step limits
info (dict): Additional step information
- Return type:
tuple[dict[str, Any], float, bool, bool, dict[str, Any]]
- render() ndarray | None[source]
Render the current environment state following Gymnasium convention.
- Returns:
RGB array of shape (height, width, 3) if render_mode is “rgb_array”.
None if render_mode is “human” or the renderer is not available.
- Return type:
ndarray | None
- save_file(file_name)[source]
Save the simulation output data to a file.
This method saves the complete simulation history to a specified file. The simulation must have been initialized with store_history=True for this method to work properly.
- Parameters:
file_name (str) – The path and filename where the output should be saved. The file format is typically NetCDF (.nc extension).
- Raises:
ctypes.ArgumentError – If the environment was created without store_history=True.
RuntimeError – If there was an error during the save operation.
Action Handling
The action handling system defines what parameters the RL agent can control and how these map to TORAX configuration updates.
Action Handler
- class gymtorax.action_handler.ActionHandler(actions: list[Action])[source]
Bases: object
Internal container and manager for multiple actions.
This class is used internally by the gymtorax framework to manage collections of actions.
- Parameters:
actions (list[Action]) – List of Action instances to manage.
- Variables:
actions – Internal dictionary of managed actions indexed by name.
action_space – Gymnasium Dict space representing all managed actions.
number_of_updates – Counter tracking the number of action updates performed.
- __init__(actions: list[Action]) None[source]
Initialize the ActionHandler with a list of actions.
Validates action compatibility, builds the action space, and sets up internal tracking structures.
- Parameters:
actions (list[Action]) – List of Action instances to manage. Actions must have unique names and compatible configuration mappings.
- Raises:
ValueError – If duplicate action names or configuration paths are found, or if incompatible actions (e.g., both Ip and Vloop) are provided.
- Return type:
None
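A minimal sketch of the compatibility checks described above, written against a hypothetical representation in which each action is a (name, config paths) pair. The paths for Ip and V_loop are taken from this page; the ECRH path is made up for illustration, and the real checks in ActionHandler may differ.

```python
def validate_actions(actions):
    """Raise ValueError for duplicate names, duplicate paths, or Ip combined with V_loop."""
    names = [name for name, _ in actions]
    if len(set(names)) != len(names):
        raise ValueError("duplicate action names")
    paths = [path for _, action_paths in actions for path in action_paths]
    if len(set(paths)) != len(paths):
        raise ValueError("duplicate configuration paths")
    if {"Ip", "V_loop"} <= set(names):
        raise ValueError("Ip and V_loop cannot be controlled at the same time")


def is_valid(actions):
    """Convenience wrapper returning a bool instead of raising."""
    try:
        validate_actions(actions)
        return True
    except ValueError:
        return False


ip = ("Ip", [("profile_conditions", "Ip")])
vloop = ("V_loop", [("profile_conditions", "v_loop_lcfs")])
ecrh = ("ECRH", [("sources", "ecrh", "P_total")])  # hypothetical path

print(is_valid([ip, ecrh]), is_valid([ip, vloop]), is_valid([ip, ip]))
```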
- get_action_variables() dict[str, list[str]][source]
Get a dictionary of state variables modified by the managed actions.
- update_actions(actions: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) None[source]
Update the current values of all managed actions.
This method validates that all provided actions exist in the handler, converts values to numpy arrays with correct dtypes, validates bounds, and updates each action’s internal values using the action’s _set_values method. The update counter is incremented after successful processing.
- Parameters:
actions (dict) – Dictionary mapping action names to their new values. Keys must correspond to existing action names in this handler. Values must be numpy arrays compatible with each action’s expected format and bounds.
- Raises:
ValueError – If any action name in the actions dict does not exist in this handler’s managed actions.
- Return type:
None
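The update flow can be illustrated with a small stand-in. TinyHandler below is hypothetical and uses plain lists instead of numpy arrays; raising ValueError for out-of-bounds values is an assumption, since only the unknown-name case is documented above.

```python
class TinyHandler:
    """Hypothetical stand-in for the update flow described above (not ActionHandler)."""

    def __init__(self, bounds):
        # bounds maps an action name to (list of minimums, list of maximums).
        self.bounds = bounds
        self.values = {}
        self.number_of_updates = 0

    def update_actions(self, actions):
        unknown = sorted(set(actions) - set(self.bounds))
        if unknown:
            raise ValueError(f"unknown action names: {unknown}")
        checked = {}
        for name, raw in actions.items():
            lows, highs = self.bounds[name]
            values = [float(v) for v in raw]
            if len(values) != len(lows):
                raise ValueError(f"{name}: expected {len(lows)} values")
            for value, low, high in zip(values, lows, highs):
                if not low <= value <= high:
                    # Assumption: the real handler may report this differently.
                    raise ValueError(f"{name}: {value} outside [{low}, {high}]")
            checked[name] = values
        # Commit only after every action passed validation.
        self.values.update(checked)
        self.number_of_updates += 1


handler = TinyHandler({"Ip": ([0.0], [float("inf")])})
handler.update_actions({"Ip": [1.5e6]})

try:
    handler.update_actions({"Vloop": [2.5]})  # name not managed by this handler
    rejected = False
except ValueError:
    rejected = True
print(handler.number_of_updates, rejected)
```

Validating everything before committing keeps the stored values and the counter consistent when one entry of the dictionary is invalid.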
- build_action_space() Dict[source]
Build a Gymnasium Dict action space from all managed actions.
Creates a dictionary-based action space where each key corresponds to an action’s name and each value is a Box space with the action’s bounds and data type.
- Returns:
Action space structure: a dictionary with action names as keys and Box spaces as values. Each Box space uses the action’s min/max bounds and dtype for proper numerical handling.
- Return type:
Dict
Base Action Class
- class gymtorax.action_handler.Action(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)[source]
Bases: ABC
Abstract base class for all TORAX simulation actions.
An action represents a controllable parameter or set of parameters that can influence plasma behavior. Each action has bounds, current values, and knows how to map itself to TORAX configuration dictionaries.
This class is designed to be extended by users to create custom actions for specific control parameters. Subclasses must define the class attributes to specify the action dimensionality, bounds, and configuration mapping.
- Class Attributes:
name (str) – Unique identifier for this action type
dimension (int) – Number of parameters controlled by this action
default_min (list[float]) – Default minimum values for parameters
default_max (list[float]) – Default maximum values for parameters
config_mapping (dict[tuple[str, …], tuple[int, float]]) – Mapping from configuration paths to parameter indices and scaling factors. Keys are tuples representing the nested path in the config dictionary, values are tuples of (parameter_index, scaling_factor).
state_var (tuple[tuple[str, …], …]) – Tuple of tuples specifying the state variables directly modified by this action. Each inner tuple contains the path to a state variable (e.g., ('scalars', 'Ip') or ('profiles', 'p_ecrh_e')).
- Variables:
ramp_rate (numpy.ndarray) – Ramp rate limits for each parameter. numpy.inf indicates no ramp rate limit for that parameter.
dtype (numpy.dtype) – NumPy data type for action arrays (default: np.float64)
- Parameters:
Example
Create a custom action for controlling two parameters:
>>> class TwoParamAction(Action):
...     name = "CustomTwoParam"
...     dimension = 2
...     default_min = [0.0, -5.0]
...     default_max = [10.0, 5.0]
...     default_ramp_rate = [0.5, None]  # First param limited, second unlimited
...     config_mapping = {
...         ('section', 'param1'): (0, 1),  # No scaling
...         ('section', 'param2'): (1, 0.5)  # Scale by 0.5
...     }
...     state_var = {'scalars': ['param1', 'param2']}
>>> action = TwoParamAction()
- __init__(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64) None[source]
Initialize an Action instance.
- Parameters:
min (list[float] | None) – Custom minimum bounds for each parameter. If None, uses the class default_min values. Must have length equal to dimension.
max (list[float] | None) – Custom maximum bounds for each parameter. If None, uses the class default_max values. Must have length equal to dimension.
ramp_rate (list[float | None] | None) – Custom ramp rate limits for each parameter. If None, uses the class default_ramp_rate values. Must have length equal to dimension. Each element can be None (no limit) or a float (max change per step).
dtype (dtype) – NumPy data type for the action arrays (default: np.float64). Used for creating action spaces.
- Raises:
ValueError – If name class attribute is not defined
ValueError – If dimension class attribute is not defined or not a positive integer
ValueError – If config_mapping class attribute is not defined
ValueError – If default_min, default_max, or default_ramp_rate do not match the dimension
ValueError – If provided min, max, or ramp_rate do not match the dimension
- Return type:
None
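One way to read the ramp_rate parameter (a float is the maximum change per step, None means no limit) is as a clamp around the previous value, followed by the usual min/max bounds. The function below is a sketch under that assumption; whether GymTorax clips a request or handles it differently is not stated here, and apply_ramp_limit is a hypothetical helper.

```python
def apply_ramp_limit(previous, requested, ramp_rate, lower, upper):
    """Clamp each requested value to the ramp window, then to the [lower, upper] bounds."""
    limited = []
    for prev, req, rate, low, high in zip(previous, requested, ramp_rate, lower, upper):
        if rate is not None:
            # Allowed window around the previous value for this step.
            req = max(prev - rate, min(prev + rate, req))
        limited.append(max(low, min(high, req)))
    return limited


# Bounds and ramp rates from the TwoParamAction example above.
lower, upper = [0.0, -5.0], [10.0, 5.0]
ramp_rate = [0.5, None]

up = apply_ramp_limit([1.0, 0.0], [3.0, 4.0], ramp_rate, lower, upper)
down = apply_ramp_limit([1.0, 0.0], [0.0, 9.0], ramp_rate, lower, upper)
print(up, down)
```

In the first call only the first parameter is slowed down by its ramp limit; in the second call the second parameter has no ramp limit but is still held at its upper bound.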
- property min: ndarray[tuple[int, ...], dtype[floating]]
Minimum bounds for this action’s parameters.
- Returns:
Array of minimum values, one for each parameter controlled by this action.
- Return type:
ndarray[tuple[int, ...], dtype[floating]]
- property max: ndarray[tuple[int, ...], dtype[floating]]
Maximum bounds for this action’s parameters.
- Returns:
Array of maximum values, one for each parameter controlled by this action.
- Return type:
ndarray[tuple[int, ...], dtype[floating]]
- init_dict(config_dict: dict[str, Any]) None[source]
Initialize a TORAX configuration dictionary with this action’s parameters.
This method sets up the configuration dictionary with the action’s current values at time=0, creating the proper time-dependent parameter structure expected by TORAX.
- Parameters:
config_dict (dict[str, Any]) – The TORAX configuration dictionary to initialize. Must have the nested structure expected by this action’s config_mapping.
- Raises:
KeyError – If the configuration dictionary does not have the expected structure for this action’s parameters.
RuntimeError – If any error occurs during the initialization process.
- Return type:
None
- update_to_config(config_dict: dict[str, Any], time: float) None[source]
Update a TORAX configuration dictionary with new action values.
This method updates the time-dependent parameters in the configuration dictionary with the action’s current values at the specified time. Scaling factors from config_mapping are applied consistently.
- Parameters:
- Return type:
None
Note
The configuration dictionary must have been initialized with init_dict before calling this method. Values are scaled by the factors defined in config_mapping before being stored.
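The interplay of init_dict, update_to_config and config_mapping can be sketched on a plain nested dictionary. The helpers sketch_init and sketch_update are hypothetical, and the {time: value} layout for time-dependent parameters is an assumption made for illustration; the mapping is the one from the TwoParamAction example.

```python
def sketch_init(config, mapping, values):
    """Replace each mapped entry with a {time: value} dict holding the value at t=0."""
    for path, (index, scale) in mapping.items():
        node = config
        for key in path[:-1]:
            node = node[key]              # KeyError if the nested structure is missing
        if path[-1] not in node:
            raise KeyError(path)
        node[path[-1]] = {0.0: values[index] * scale}


def sketch_update(config, mapping, values, time):
    """Add the scaled values at the given time to the existing {time: value} dicts."""
    for path, (index, scale) in mapping.items():
        node = config
        for key in path:
            node = node[key]
        node[time] = values[index] * scale


mapping = {
    ("section", "param1"): (0, 1),      # parameter 0, no scaling
    ("section", "param2"): (1, 0.5),    # parameter 1, scaled by 0.5
}
config = {"section": {"param1": 0.0, "param2": 0.0}}

sketch_init(config, mapping, [2.0, 4.0])
sketch_update(config, mapping, [3.0, -2.0], time=1.0)
print(config)
```

Each configuration entry ends up with one value per action time, which is why the initialization step has to run before any update.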
Concrete Actions
- class gymtorax.action_handler.IpAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)[source]
Bases: Action
Example action for controlling plasma current (Ip).
This action controls the plasma current parameter in TORAX simulations. It is a single-parameter action with non-negative bounds.
- Class Attributes:
name – "Ip"
dimension – 1 (single parameter)
default_min – [_MIN_IP_AMPS] (minimum current per TORAX requirements)
default_max – [numpy.inf]
default_ramp_rate – [None]
config_mapping – Maps to ('profile_conditions', 'Ip')
state_var – {'scalars': ['Ip']} - directly modifies plasma current scalar
- Action Parameters:
0 – Plasma current (Ip) in Amperes
- Parameters:
Example
>>> ip_action = IpAction()
>>> ip_action._set_values([1.5e6])  # 1.5 MA plasma current
- class gymtorax.action_handler.VloopAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)[source]
Bases: Action
Example action for controlling loop voltage at the last closed flux surface.
This action controls the loop voltage parameter (v_loop_lcfs) in TORAX simulations. It is a single-parameter action with non-negative bounds.
- Class Attributes:
name – "V_loop"
dimension – 1 (single parameter)
default_min – [0.0]
default_max – [numpy.inf]
default_ramp_rate – [None]
config_mapping – Maps to ('profile_conditions', 'v_loop_lcfs')
state_var – {'scalars': ['v_loop_lcfs']} - directly modifies loop voltage scalar
- Action Parameters:
0 – Loop voltage (v_loop_lcfs) in Volts
- Parameters:
Example
>>> vloop_action = VloopAction()
>>> vloop_action._set_values([2.5])  # 2.5 V loop voltage
- class gymtorax.action_handler.EcrhAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)[source]
Bases: Action
Example action for controlling Electron Cyclotron Resonance Heating (ECRH).
This action controls three ECRH parameters: total power, Gaussian location, and Gaussian width of the heating profile.
- Class Attributes:
name – "ECRH"
dimension – 3 (power, location, width)
default_min – [0.0, 0.0, 0.0]
default_max – [numpy.inf, numpy.inf, numpy.inf]
default_ramp_rate – [None, None, None]
config_mapping – Maps to ECRH source parameters
state_var – {'scalars': ['P_ecrh_e']} - modifies total electron-cyclotron power scalar
- Action Parameters:
0 – Total power (P_total) in Watts
1 – Gaussian location (gaussian_location) - normalized radius [0,1]
2 – Gaussian width (gaussian_width) - profile width parameter
- Parameters:
Example
>>> ecrh_action = EcrhAction()
>>> ecrh_action._set_values([5e6, 0.3, 0.1])  # 5MW, r/a=0.3, width=0.1
- class gymtorax.action_handler.NbiAction(nbi_w_to_ma=1 / 16e6, **kwargs)[source]
Bases: Action
Example action for controlling Neutral Beam Injection (NBI).
This action controls three NBI parameters: heating power, Gaussian location, and Gaussian width of the heating profile. The current drive power is automatically calculated from the heating power using a configurable conversion factor.
- Class Attributes:
name – "NBI"
dimension – 3 (heating power, location, width)
default_min – [0.0, 0.0, 0.01]
default_max – [numpy.inf, 1.0, numpy.inf]
default_ramp_rate – [None, None, None]
config_mapping – Maps to generic heat and current source parameters in TORAX configuration
state_var – {'scalars': ['P_aux_generic_total']} - modifies total auxiliary power scalar
- Variables:
- Action Parameters:
0 – Heating power (generic_heat P_total) in Watts
1 – Gaussian location (shared by heat and current) - normalized radius [0,1]
2 – Gaussian width (shared by heat and current) - profile width parameter
Example
>>> nbi_action = NbiAction()
>>> nbi_action._set_values([10e6, 0.4, 0.2])  # 10MW heating, r/a=0.4, width=0.2
>>> # NBI with custom conversion factor
>>> nbi_custom = NbiAction(nbi_w_to_ma=1/20e6)  # 20MW per 1MA
>>> nbi_custom._set_values([20e6, 0.3, 0.15])
>>> # NBI with current drive disabled
>>> nbi_heating_only = NbiAction(nbi_w_to_ma=0)
>>> nbi_heating_only._set_values([15e6, 0.5, 0.1])
- __init__(nbi_w_to_ma=1 / 16e6, **kwargs)[source]
Initialize NbiAction with configurable heating-to-current conversion.
- Parameters:
nbi_w_to_ma – Conversion factor from heating power (Watts) to current drive (MA). Default is 1/16e6, meaning 16MW of heating produces 1MA of current drive. Set to 0 to disable current drive while keeping heating.
**kwargs – Additional arguments passed to the parent Action class (min, max, ramp_rate, dtype).
Example
>>> # Default conversion (16MW -> 1MA)
>>> nbi = NbiAction()
>>> # Custom conversion (20MW -> 1MA)
>>> nbi = NbiAction(nbi_w_to_ma=1/20e6)
>>> # Heating only, no current drive
>>> nbi = NbiAction(nbi_w_to_ma=0)
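The conversion factor amounts to a single multiplication. The helper below is a hypothetical illustration of that arithmetic only (current drive in MA equals heating power in W times nbi_w_to_ma); it is not how NbiAction writes the value into the TORAX configuration.

```python
def nbi_current_drive_ma(heating_power_w, nbi_w_to_ma=1 / 16e6):
    """Current drive in MA implied by a heating power in W and the conversion factor."""
    return heating_power_w * nbi_w_to_ma


default_ma = nbi_current_drive_ma(16e6)                     # default factor: 16 MW gives about 1 MA
partial_ma = nbi_current_drive_ma(10e6)                     # 10 MW gives about 0.625 MA
custom_ma = nbi_current_drive_ma(20e6, nbi_w_to_ma=1 / 20e6)  # 20 MW per 1 MA
heating_only = nbi_current_drive_ma(15e6, nbi_w_to_ma=0)    # current drive disabled
print(default_ma, partial_ma, custom_ma, heating_only)
```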
Observation Handling
The observation system extracts relevant plasma state information and formats it for RL agents.
Base Observation Class
- class gymtorax.observation_handler.Observation(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64)[source]
Bases: ABC
Abstract base class for building observation spaces from TORAX DataTree outputs.
Converts TORAX simulation outputs into structured observation spaces for reinforcement learning environments. Handles variable selection, bounds specification, and automatic action variable filtering.
- Variables:
variables_to_include (dict) – Variables to include in observation space.
variables_to_exclude (dict) – Variables to exclude from observation space.
custom_bounds (dict) – Custom bounds for variables.
dtype (numpy.dtype) – Data type for observation arrays.
action_variables (dict) – Variables controlled by actions.
state_variables (dict) – Available variables from TORAX output.
observation_variables (dict) – Final filtered observation variables.
bounds (dict) – Final bounds after processing.
- Parameters:
- __init__(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64) None[source]
Initialize Observation handler.
Sets up configuration for the observation handler. Requires subsequent calls to set_state_variables(), set_action_variables(), and build_observation_space() before use.
- Parameters:
variables (dict[str, list[str]] | None) – Variables to include. Format: {"profiles": [names], "scalars": [names]}. If None, includes all available variables except those in exclude.
custom_bounds_filename (str | None) – Path to JSON file with custom bounds. Format: {"profiles": {var: {"min": val, "max": val}}, "scalars": {...}}.
exclude (dict[str, list[str]] | None) – Variables to exclude. Format: {"profiles": [names], "scalars": [names]}. Cannot be used with variables parameter.
dtype (dtype) – Data type for observation arrays.
- Raises:
ValueError – If both variables and exclude are specified, or if the configuration is invalid.
- Return type:
None
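A custom bounds file in the documented format can be produced with the standard json module. In the sketch below the variable names and numeric limits are placeholders, and lookup_bounds is a hypothetical helper; falling back to unbounded limits for variables missing from the file is an assumption.

```python
import json
import os
import tempfile

# Placeholder variables and limits in the documented layout.
bounds = {
    "profiles": {"T_e": {"min": 0.0, "max": 100.0}},
    "scalars": {"beta_N": {"min": 0.0, "max": 10.0}},
}


def lookup_bounds(custom, category, name):
    """Return (min, max) from the custom bounds, or unbounded limits when absent."""
    entry = custom.get(category, {}).get(name)
    if entry is None:
        return (float("-inf"), float("inf"))
    return (entry["min"], entry["max"])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "bounds.json")
    with open(path, "w") as handle:
        json.dump(bounds, handle, indent=2)
    with open(path) as handle:
        loaded = json.load(handle)

print(lookup_bounds(loaded, "scalars", "beta_N"))
print(lookup_bounds(loaded, "scalars", "not_listed"))
```

The resulting file path is what would be passed as the custom bounds argument when constructing an observation handler.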
- set_action_variables(variables: dict[str, list[str]]) None[source]
Set variables controlled by actions.
These variables are removed from the observation space to prevent redundancy between actions and observations.
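The selection rules (either variables or exclude, then removal of action-controlled variables) can be sketched as a pure function over name lists. select_variables is a hypothetical helper and the variable names are illustrative; how the real class treats requested names that are not available is not covered.

```python
def select_variables(available, variables=None, exclude=None, action_variables=None):
    """Return the observation variables per category after include/exclude and action filtering."""
    if variables is not None and exclude is not None:
        raise ValueError("use either 'variables' or 'exclude', not both")
    selected = {}
    for category, names in available.items():
        if variables is not None:
            wanted = set(variables.get(category, []))
            keep = [name for name in names if name in wanted]
        else:
            skipped = set((exclude or {}).get(category, []))
            keep = [name for name in names if name not in skipped]
        # Variables driven by actions are dropped to avoid duplicating the action.
        controlled = set((action_variables or {}).get(category, []))
        selected[category] = [name for name in keep if name not in controlled]
    return selected


available = {"profiles": ["T_e", "n_impurity", "q"], "scalars": ["Ip", "beta_N"]}
observed = select_variables(
    available,
    exclude={"profiles": ["n_impurity"]},
    action_variables={"scalars": ["Ip"]},
)
print(observed)
```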
- set_state_variables(state: DataTree) None[source]
Set available state variables from TORAX output.
Catalogs all variables from the TORAX DataTree for inclusion in the observation space.
- Parameters:
state (DataTree) – TORAX DataTree with /profiles/ and /scalars/ datasets.
- Return type:
None
- extract_state_observation(datatree: DataTree) tuple[dict[str, dict[str, ndarray]], dict[str, dict[str, ndarray]]][source]
Extract state and observation data from TORAX output.
Returns both complete state (all variables) and filtered observation (selected variables only).
- Parameters:
datatree (DataTree) – TORAX simulation output containing profiles and scalars datasets.
- Returns:
state (dict): Complete state with all variables. Format: {"profiles": {var: array}, "scalars": {var: value}}
observation (dict): Filtered observation with selected variables. Format: {"profiles": {var: array}, "scalars": {var: value}}
- Return type:
tuple[dict[str, dict[str, numpy.ndarray]], dict[str, dict[str, numpy.ndarray]]]
- build_observation_space() Dict[source]
Build Gymnasium observation space for selected variables.
Creates nested Dict space with Box spaces for each variable using configured bounds. Validates configuration and finalizes variable selection.
- Returns:
Gymnasium Dict space with structure {"profiles": {var: Box}, "scalars": {var: Box}}
- Return type:
Dict
- Raises:
ValueError – If validation fails or required setup incomplete.
Concrete Observations
- class gymtorax.observation_handler.AllObservation(exclude=None, custom_bounds_file=None)[source]
Bases: Observation
Observation handler that includes all available TORAX variables.
Creates a complete observation space containing all profile and scalar variables available in the TORAX simulation output. Supports variable exclusions and custom bounds configuration.
Example
>>> obs = AllObservation()
>>> obs = AllObservation(exclude={"profiles": ["n_impurity"]})
>>> obs = AllObservation(custom_bounds_file="bounds.json")
- __init__(exclude=None, custom_bounds_file=None) None[source]
Initialize AllObservation with all available TORAX variables.
Creates an observation handler that includes all available TORAX variables by default, with flexible configuration through keyword arguments.
- Parameters:
- Return type:
None
Example
>>> obs = AllObservation()
>>> obs = AllObservation(exclude={"profiles": ["psi"]})
Reward Functions
Reward functions define the control objectives and translate plasma performance into scalar signals for RL training.
TORAX reward module.
This module provides functions to extract specific metrics from the state dictionary returned by the TORAX simulator. These metrics can be used to construct reward functions for reinforcement learning environments focused on tokamak control and optimization. Users can also define their own reward functions.
- gymtorax.rewards.get_fusion_gain(state: dict) float[source]
Get the fusion gain \(Q\) from the state dictionary.
- gymtorax.rewards.get_beta_N(state: dict) float[source]
Get the normalized \(\beta_N\) from the state dictionary.
- gymtorax.rewards.get_tau_E(state: dict) float[source]
Get the energy confinement time \(\tau_E\) from the state dictionary.
- gymtorax.rewards.get_h98(state: dict) float[source]
Get the H-mode confinement quality factor from the state dictionary.
- gymtorax.rewards.get_q_profile(state: dict) ndarray[source]
Get the safety factor profile \(q\) from the state dictionary.
- Parameters:
state (dict) – The state dictionary containing profile values.
- Returns:
The safety factor profile \(q\).
- Return type:
ndarray
- gymtorax.rewards.get_q_min(state: dict) float[source]
Get the minimum safety factor \(q_{min}\) from the state dictionary.
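Metrics like these are typically combined into one scalar reward. The sketch below uses local stand-in getters instead of the gymtorax.rewards functions; it assumes the state layout {"scalars": {...}, "profiles": {...}}, the "beta_N" scalar used in the BaseEnv example, and a "q" profile key. The target, limit and penalty values are arbitrary placeholders.

```python
def beta_n_from_state(state):
    """Stand-in getter: normalized beta from the scalars of the state dictionary."""
    return float(state["scalars"]["beta_N"])


def q_min_from_state(state):
    """Stand-in getter: minimum of the safety factor profile (the 'q' key is assumed)."""
    return float(min(state["profiles"]["q"]))


def reward(state, target_beta_n=2.0, q_min_limit=1.0, penalty=10.0):
    """Track a beta_N target and penalize states whose minimum q drops below a limit."""
    value = -abs(beta_n_from_state(state) - target_beta_n)
    if q_min_from_state(state) < q_min_limit:
        value -= penalty
    return value


safe_state = {"scalars": {"beta_N": 1.5}, "profiles": {"q": [3.2, 1.8, 1.1]}}
risky_state = {"scalars": {"beta_N": 1.5}, "profiles": {"q": [3.2, 1.8, 0.9]}}
print(reward(safe_state), reward(risky_state))
```

A function of this shape could be called from a _compute_reward override, with next_state as its argument.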