Environment Designers

The Gym-TORAX package provides a powerful framework for plasma physics specialists to explore and create new environments. It is designed as a flexible toolkit that allows users with a background in plasma physics to develop and test their own scenarios, configurations, and experiments. This includes defining new actions, observations, the TORAX configuration (for the plasma properties) and rewards.

Creation of New Environments

Base Environment

class gymtorax.envs.base_env.BaseEnv(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False)

Bases: Env, ABC

Abstract base class for TORAX plasma simulation environments.

This class integrates TORAX physics simulations with the Gymnasium reinforcement learning framework, providing a standardized interface for plasma control tasks. It handles the complexities of time discretization, simulation management, and action/observation space construction.

The environment operates by:

  1. Setting up logging configuration for debugging and monitoring

  2. Initializing TORAX configuration and simulation state

  3. Managing discrete time steps with configurable time intervals

  4. Applying actions by updating TORAX configuration parameters

  5. Executing simulation steps and extracting observations

  6. Computing rewards and determining episode termination

Variables:
  • observation_handler (Observation) – Handles observation space and data extraction

  • action_handler (ActionHandler) – Manages action space and parameter updates

  • config (ConfigLoader) – TORAX configuration manager

  • torax_app (ToraxApp) – TORAX simulation wrapper

  • state (dict) – Current complete plasma state

  • observation (dict) – Current filtered observation

  • T (float) – Total simulation time [s]

  • delta_t_a (float) – Time interval between actions [s]

  • current_time (float) – Current simulation time [s]

  • timestep (int) – Current timestep counter

  • terminated (bool) – Episode termination flag

  • truncated (bool) – Episode truncation flag

Parameters:
  • render_mode (str | None)

  • log_level (str)

  • log_file (str | None)

  • plot_config (FigureProperties | str)

  • store_history (bool)

Note

Subclasses must implement these abstract methods:

  • _define_observation_space: Define observation space variables

  • _define_action_space: Define available control actions

  • _get_torax_config: Define TORAX configuration parameters

  • _compute_reward: Define reward signal (optional override)

__init__(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False) -> None

Initialize the TORAX gymnasium environment.

This method sets up the complete simulation environment including TORAX configuration, action/observation spaces, time discretization, and rendering components.

Parameters:
  • render_mode (str or None) – Rendering mode for visualization. Options: "human", "rgb_array", or None. Defaults to None.

  • log_level (str) – Logging level for environment operations. Options: "debug", "info", "warning", "error", "critical". Defaults to "warning".

  • log_file (str or None) – Path to log file for writing log messages. If None, logs to console. Defaults to None.

  • plot_config (str or FigureProperties) – Name of the plot configuration to use (e.g., "default"). Can also be a torax FigureProperties instance for custom plot configuration.

  • store_history (bool) – Whether to store simulation history for later saving. Set to True if you plan to use save_file method. Defaults to False.

Raises:
  • ValueError – If required parameters are missing for chosen discretization method.

  • TypeError – If discretization_torax is not "auto" or "fixed".

  • KeyError – If required keys are missing from TORAX configuration.

Return type:

None

Note

Subclasses should use **kwargs to pass parameters to avoid explicit parameter listing and maintain flexibility as the base class evolves. Environment-specific defaults can be set using kwargs.setdefault() before calling super().__init__().

The environment must implement the abstract methods _define_observation_space, _define_action_space, _get_torax_config, and _compute_reward.
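The keyword-forwarding pattern described in the note above can be sketched as follows. BaseEnvStandIn is a hypothetical stand-in that only mirrors the constructor keywords of BaseEnv, so the snippet runs without TORAX installed; the chosen defaults are illustrative.

```python
class BaseEnvStandIn:
    """Hypothetical stand-in for BaseEnv; it only mirrors the constructor keywords."""

    def __init__(self, render_mode=None, log_level="warning", log_file=None,
                 plot_config="default", store_history=False):
        self.render_mode = render_mode
        self.log_level = log_level
        self.log_file = log_file
        self.plot_config = plot_config
        self.store_history = store_history


class MyEnv(BaseEnvStandIn):
    def __init__(self, **kwargs):
        # Environment-specific defaults; values passed by the caller still win.
        kwargs.setdefault("log_level", "info")
        kwargs.setdefault("store_history", True)
        super().__init__(**kwargs)


default_env = MyEnv()
debug_env = MyEnv(log_level="debug")
print(default_env.log_level, debug_env.log_level)  # info debug
```

Because the subclass never lists the base-class parameters explicitly, a keyword added to the base constructor later would pass through unchanged.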

reset(*, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[dict[str, Any], dict[str, Any]]

Reset the environment to its initial state for a new episode.

This method initializes a new simulation episode by:

  1. Resetting internal counters and flags

  2. Starting the TORAX simulation from initial conditions

  3. Extracting the initial observation state

  4. Optionally rendering the initial state

Parameters:
  • seed (int or None) – Random seed for reproducible episode initialization. Used to seed the environment’s random number generator for deterministic behavior across resets. If None, no seeding is performed. Defaults to None.

  • options (dict[str, Any] or None) – Additional options for environment reset. Currently unused but maintained for Gymnasium compatibility. Defaults to None.

Returns:

  • observation (dict): Initial observation of plasma state

  • info (dict): Additional information (empty dict)

Return type:

tuple[dict[str, Any], dict[str, Any]]

step(action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]

Execute one environment step with the given action.

This method implements the core RL interaction by:

  1. Capturing the current state before action

  2. Applying the action to update TORAX configuration

  3. Running the simulation for one time interval

  4. Extracting the new observation state

  5. Computing the reward signal

  6. Checking for episode termination

  7. Updating time counters

Parameters:

action (dict[str, numpy.ndarray]) – Action dictionary containing parameter values for all configured actions.

Returns:

  • observation (dict): New plasma state observation

  • reward (float): Reward signal for this step

  • terminated (bool): True if episode ended due to terminal condition

  • truncated (bool): True if episode ended due to time/step limits

  • info (dict): Additional step information

Return type:

tuple[dict[str, Any], float, bool, bool, dict[str, Any]]
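The reset and step return values documented above fit the usual Gymnasium episode loop, sketched below. StubEnv is a hypothetical stand-in that reproduces only the return contract (no physics); the "Ip" action key, the plain-list action values, and the rule that the episode is truncated once current_time reaches T are assumptions made for illustration.

```python
class StubEnv:
    """Hypothetical stand-in following the reset/step return contract only."""

    def __init__(self, T=3.0, delta_t_a=1.0):
        self.T = T                  # total simulation time [s]
        self.delta_t_a = delta_t_a  # time interval between actions [s]
        self.current_time = 0.0
        self.timestep = 0

    def reset(self, *, seed=None, options=None):
        self.current_time = 0.0
        self.timestep = 0
        observation = {"profiles": {}, "scalars": {"Ip": 0.0}}
        return observation, {}

    def step(self, action):
        self.current_time += self.delta_t_a
        self.timestep += 1
        observation = {"profiles": {}, "scalars": {"Ip": action["Ip"][0]}}
        reward = 0.0
        terminated = False
        truncated = self.current_time >= self.T  # assumed time-limit rule
        return observation, reward, terminated, truncated, {}


def run_episode(env, policy):
    """Run one episode and return the accumulated reward and the step count."""
    observation, info = env.reset(seed=0)
    total_reward, steps, done = 0.0, 0, False
    while not done:
        action = policy(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
        done = terminated or truncated
    return total_reward, steps


total_reward, steps = run_episode(StubEnv(), lambda obs: {"Ip": [1.0e6]})
print(steps)  # 3
```

With a real environment the action values would be NumPy arrays, as stated in the step signature.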

close() -> None

Clean up environment resources.

Return type:

None

render() -> ndarray | None

Render the current environment state following Gymnasium convention.

Returns:

  • RGB array of shape (height, width, 3) if render_mode is "rgb_array"

  • None if render_mode is "human" or the renderer is not available

Return type:

numpy.ndarray | None

save_file(file_name)

Save the simulation output data to a file.

This method saves the complete simulation history to a specified file. The simulation must have been initialized with store_history=True for this method to work properly.

Parameters:

file_name (str) – The path and filename where the output should be saved. The file format is typically NetCDF (.nc extension).

Abstract Methods

abstract BaseEnv._define_action_space() -> list[Action]

Define the available control actions for this environment.

This method must be implemented by concrete subclasses to specify which plasma parameters can be controlled by the RL agent.

Returns:

List of Action instances representing controllable parameters with their bounds and TORAX configuration mappings.

Return type:

list[Action]

Example

>>> def _define_action_space(self):
...     return [
...         IpAction(min=[0.5e6], max=[2.0e6]),      # Plasma current
...         EcrhAction(                               # ECRH heating
...             min=[0.0, 0.0, 0.0],                # [power, loc, width]
...             max=[10e6, 1.0, 0.5]
...         ),
...         NbiAction()                               # NBI with defaults
...     ]

abstract BaseEnv._define_observation_space() -> Observation

Define the observation space variables for this environment.

This method must be implemented by concrete subclasses to specify which TORAX variables should be included in the observation space.

Returns:

Configured observation handler that defines which plasma state variables are visible to the RL agent.

Return type:

Observation

Example

>>> def _define_observation_space(self):
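...     return AllObservation(
...         # Exclusions are grouped by category, as in Observation.__init__
...         exclude={"profiles": ["n_impurity", "Z_impurity"]},
...         # JSON file with bounds, e.g. T_e and T_i in [0, 50] keV
...         # (illustrative file name)
...         custom_bounds_filename="custom_bounds.json",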
...     )

abstract BaseEnv._get_torax_config() -> dict[str, Any]

Define the TORAX simulation configuration.

Concrete subclasses must implement this abstract method to provide the parameters for the TORAX simulation: its core configuration, the time discretization method, and, depending on that method, either the time interval between actions or the ratio between action and simulation time steps.

Returns:

A dictionary containing the TORAX configuration.

The dictionary must have the following keys:

  • "config" (dict): A dictionary of TORAX configuration parameters.

  • "discretization" (str): The time discretization method. Options are "auto" (uses 'delta_t_a') or "fixed" (uses 'ratio_a_sim').

  • "ratio_a_sim" (int or None): The ratio of action timesteps to simulation timesteps. Required if 'discretization' is "fixed".

  • "delta_t_a" (float or None): The time interval between actions in seconds. Required if 'discretization' is "auto".

Return type:

dict[str, Any]
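The key requirements listed above can be checked before an environment is constructed. The function below is a hypothetical helper, not part of Gym-TORAX; it assumes the same exception types as those documented for BaseEnv.__init__ (KeyError, TypeError, ValueError).

```python
def check_torax_config(cfg):
    """Hypothetical validator for the dictionary returned by _get_torax_config."""
    # Keys needed regardless of the discretization method.
    for key in ("config", "discretization"):
        if key not in cfg:
            raise KeyError(f"missing required key: {key!r}")

    # Each discretization method needs exactly one extra parameter.
    method = cfg["discretization"]
    if method == "auto":
        required = "delta_t_a"
    elif method == "fixed":
        required = "ratio_a_sim"
    else:
        raise TypeError(f"discretization must be 'auto' or 'fixed', got {method!r}")

    if cfg.get(required) is None:
        raise ValueError(f"{required!r} is required when discretization is {method!r}")
    return cfg


check_torax_config({"config": {}, "discretization": "auto", "delta_t_a": 0.05})
check_torax_config({"config": {}, "discretization": "fixed", "ratio_a_sim": 10})
```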

Example

>>> def _get_torax_config(self):
...     return {
...         "config": TORAX_CONFIG,
...         "discretization": "auto",
...         "delta_t_a": 0.05,  # 50 ms between actions
...         # "ratio_a_sim": 10, # Only needed if using "fixed" discretization
...     }

abstract BaseEnv._compute_reward(state: dict[str, Any], next_state: dict[str, Any], action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) -> float

Define the reward signal for a state transition.

This method should be overridden by concrete subclasses to implement task-specific reward functions. The default implementation returns 0.0.

Parameters:
  • state (dict[str, Any]) – Previous plasma state before action was applied. Contains complete state with "profiles" and "scalars" dictionaries.

  • next_state (dict[str, Any]) – New plasma state after action and simulation step. Same structure as state parameter.

  • action (dict[str, numpy.ndarray]) – Action dictionary that was applied to cause this transition.

Returns:

Reward value for this state transition.

Return type:

float

Example

>>> def _compute_reward(self, state, next_state, action):
...     # Reward based on proximity to target beta_N
...     target_beta = 2.0
...     current_beta = next_state["scalars"]["beta_N"]
...     return -abs(current_beta - target_beta)

Here is a simple example of how to create a new environment by extending the base class:

from gymtorax.envs.base_env import BaseEnv
import gymtorax.action_handler as ah
import gymtorax.observation_handler as oh
import gymtorax.rewards as rw

# CONFIG is the TORAX configuration dictionary (see "Creation of Configurations").
# It must be defined or imported before the environment is instantiated.

class CustomEnv(BaseEnv):
    def _define_action_space(self):
        # Control the plasma current only, with its default bounds.
        actions = [ah.IpAction()]
        return actions

    def _define_observation_space(self):
        # Observe every available variable.
        return oh.AllObservation()

    def _get_torax_config(self):
        return {
            "config": CONFIG,
            "discretization": "auto",
            "delta_t_a": 1.0,  # 1 s between actions
        }

    def _compute_reward(self, current_state, next_state, action):
        Q = rw.get_fusion_gain(next_state)
        q_min = rw.get_q_min(next_state)
        w_Q, w_qmin = 1.0, 1.0

        def q_min_function():
            # Bonus of 1 when q_min is above 1, otherwise 0.
            if q_min <= 1:
                return 0
            return 1

        return w_Q * Q + w_qmin * q_min_function()

Creation of Actions

class gymtorax.action_handler.Action(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)

Bases: ABC

Abstract base class for all TORAX simulation actions.

An action represents a controllable parameter or set of parameters that can influence plasma behavior. Each action has bounds, current values, and knows how to map itself to TORAX configuration dictionaries.

This class is designed to be extended by users to create custom actions for specific control parameters. Subclasses must define the class attributes to specify the action dimensionality, bounds, and configuration mapping.

Class Attributes:
  • name (str) – Unique identifier for this action type

  • dimension (int) – Number of parameters controlled by this action

  • default_min (list[float]) – Default minimum values for parameters

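  • default_max (list[float]) – Default maximum values for parameters

  • default_ramp_rate (list[float | None]) – Default ramp rate limits for parameters. None means no limit for that parameter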

  • config_mapping (dict[tuple[str, …], tuple[int, float]]) – Mapping from configuration paths to parameter indices and scaling factors. Keys are tuples representing the nested path in the config dictionary, values are tuples of (parameter_index, scaling_factor).

  • state_var (dict[str, list[str]]) – State variables directly modified by this action, grouped by category (e.g., {'scalars': ['Ip']} or {'profiles': ['p_ecrh_e']}).

Variables:
  • values (list[float]) – Current parameter values

  • ramp_rate (numpy.ndarray) – Ramp rate limits for each parameter. numpy.inf indicates no ramp rate limit for that parameter.

  • dtype (numpy.dtype) – NumPy data type for action arrays (default: np.float64)

Example

Create a custom action for controlling two parameters:

>>> class TwoParamAction(Action):
...     name = "CustomTwoParam"
...     dimension = 2
...     default_min = [0.0, -5.0]
...     default_max = [10.0, 5.0]
...     default_ramp_rate = [0.5, None]  # First param limited, second unlimited
...     config_mapping = {
...         ('section', 'param1'): (0, 1),      # No scaling
...         ('section', 'param2'): (1, 0.5)     # Scale by 0.5
...     }
...     state_var = {'scalars': ['param1', 'param2']}
>>> action = TwoParamAction()

dimension: int
name: str
default_min: list[float]
default_max: list[float]
default_ramp_rate: list[float | None]
config_mapping: dict[tuple[str, ...], int]
state_var: dict[str, list[str]] = {}

__init__(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64) -> None

Initialize an Action instance.

Parameters:
  • min (list[float] | None) – Custom minimum bounds for each parameter. If None, uses the class default_min values. Must have length equal to dimension.

  • max (list[float] | None) – Custom maximum bounds for each parameter. If None, uses the class default_max values. Must have length equal to dimension.

  • ramp_rate (list[float | None] | None) – Custom ramp rate limits for each parameter. If None, uses the class default_ramp_rate values. Must have length equal to dimension. Each element can be None (no limit) or a float (max change per step).

  • dtype (dtype) – NumPy data type for the action arrays (default: np.float64). Used for creating action spaces.

Raises:
  • ValueError – If name class attribute is not defined

  • ValueError – If dimension class attribute is not defined or not a positive integer

  • ValueError – If config_mapping class attribute is not defined

  • ValueError – If default_min, default_max, or default_ramp_rate do not match the dimension

  • ValueError – If provided min, max, or ramp_rate do not match the dimension

Return type:

None
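A ramp rate caps how far a parameter may move from one step to the next. The sketch below shows one plausible reading, in which the requested value is clipped to the previous value plus or minus the ramp rate. apply_ramp_limit is a hypothetical helper, and whether the library clips in exactly this way is an assumption.

```python
import math


def apply_ramp_limit(previous, requested, ramp_rate):
    """Clip each requested value to previous +/- ramp_rate; None means no limit."""
    limited = []
    for prev, req, rate in zip(previous, requested, ramp_rate):
        max_change = math.inf if rate is None else rate
        low, high = prev - max_change, prev + max_change
        limited.append(min(max(req, low), high))
    return limited


# First parameter limited to a change of 0.5 per step, second unlimited.
print(apply_ramp_limit([1.0, 0.0], [3.0, 4.0], [0.5, None]))  # [1.5, 4.0]
```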

property min: ndarray[tuple[int, ...], dtype[floating]]

Minimum bounds for this action's parameters.

Returns:

Array of minimum values, one for each parameter controlled by this action.

Return type:

numpy.ndarray

property max: ndarray[tuple[int, ...], dtype[floating]]

Maximum bounds for this action's parameters.

Returns:

Array of maximum values, one for each parameter controlled by this action.

Return type:

numpy.ndarray

init_dict(config_dict: dict[str, Any]) -> None

Initialize a TORAX configuration dictionary with this action's parameters.

This method sets up the configuration dictionary with the action's current values at time=0, creating the time-dependent parameter structure expected by TORAX.

Parameters:

config_dict (dict[str, Any]) – The TORAX configuration dictionary to initialize. Must have the nested structure expected by this action's config_mapping.

Raises:
  • KeyError – If the configuration dictionary does not have the expected structure for this action's parameters.

  • RuntimeError – If any error occurs during the initialization process.

Return type:

None

update_to_config(config_dict: dict[str, Any], time: float) -> None

Update a TORAX configuration dictionary with new action values.

This method updates the time-dependent parameters in the configuration dictionary with the action's current values at the specified time. Scaling factors from config_mapping are applied consistently.

Parameters:
  • config_dict (dict[str, Any]) – The TORAX configuration dictionary to update. Must have been previously initialized with init_dict.

  • time (float) – Simulation time for this update. Must be > 0.

Return type:

None

Note

The configuration dictionary must have been initialized with init_dict before calling this method. Values are scaled by the factors defined in config_mapping before being stored.
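The interplay of init_dict, update_to_config, and config_mapping can be illustrated with plain dictionaries. The two functions below are hypothetical re-implementations, not the library code. They assume that each time-dependent parameter is stored as a {time: value} dictionary, which is one of the forms TORAX accepts for time-varying scalars; the layout Gym-TORAX uses internally may differ.

```python
def init_from_mapping(config, mapping, values):
    """Write the scaled values at time 0 as {time: value} entries."""
    for path, (index, factor) in mapping.items():
        node = config
        for key in path[:-1]:
            node = node[key]  # raises KeyError if the structure is missing
        node[path[-1]] = {0.0: values[index] * factor}


def update_from_mapping(config, mapping, values, time):
    """Append the scaled values at a later time to the existing entries."""
    for path, (index, factor) in mapping.items():
        node = config
        for key in path:
            node = node[key]
        node[time] = values[index] * factor


# Mapping taken from the TwoParamAction example: (parameter_index, scaling_factor).
mapping = {("section", "param1"): (0, 1), ("section", "param2"): (1, 0.5)}
config = {"section": {"param1": None, "param2": None}}

init_from_mapping(config, mapping, [4.0, 2.0])
update_from_mapping(config, mapping, [6.0, -2.0], time=1.0)
print(config["section"]["param2"])  # {0.0: 1.0, 1.0: -1.0}
```

The second parameter is stored at half its action value because its scaling factor is 0.5.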

get_mapping() -> dict[tuple[str, ...], int]

Get the mapping of configuration dictionary paths to action indices and factors.

Returns:

Mapping of config dictionary paths to tuples of (action_parameter_index, scaling_factor).

Return type:

dict[tuple[str, …], tuple[int, float]]

__repr__() -> str

Return a string representation of the action.

Returns:

String showing the action class name, current values, and bounds.

Return type:

str

Creation of Observations

class gymtorax.observation_handler.Observation(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64)

Bases: ABC

Abstract base class for building observation spaces from TORAX DataTree outputs.

Converts TORAX simulation outputs into structured observation spaces for reinforcement learning environments. Handles variable selection, bounds specification, and automatic action variable filtering.

Variables:
  • variables_to_include (dict) – Variables to include in observation space.

  • variables_to_exclude (dict) – Variables to exclude from observation space.

  • custom_bounds (dict) – Custom bounds for variables.

  • dtype (numpy.dtype) – Data type for observation arrays.

  • action_variables (dict) – Variables controlled by actions.

  • state_variables (dict) – Available variables from TORAX output.

  • observation_variables (dict) – Final filtered observation variables.

  • bounds (dict) – Final bounds after processing.

__init__(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64) -> None

Initialize Observation handler.

Sets up configuration for the observation handler. Requires subsequent calls to set_state_variables(), set_action_variables(), and build_observation_space() before use.

Parameters:
  • variables (dict[str, list[str]] | None) – Variables to include. Format: {"profiles": [names], "scalars": [names]}. If None, includes all available variables except those in exclude.

  • custom_bounds_filename (str | None) – Path to JSON file with custom bounds. Format: {"profiles": {var: {"min": val, "max": val}}, "scalars": {...}}.

  • exclude (dict[str, list[str]] | None) – Variables to exclude. Format: {"profiles": [names], "scalars": [names]}. Cannot be used with variables parameter.

  • dtype (dtype) – Data type for observation arrays.

Raises:

ValueError – If both variables and exclude are specified, or if the configuration is otherwise invalid.

Return type:

None
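A custom bounds file in the JSON format described for custom_bounds_filename can be produced with the standard library. The variable names and numeric ranges below are illustrative only, and the file is written to a temporary directory for the sake of the example.

```python
import json
import os
import tempfile

# Illustrative bounds; the layout follows the documented format
# {"profiles": {var: {"min": val, "max": val}}, "scalars": {...}}.
bounds = {
    "profiles": {
        "T_e": {"min": 0.0, "max": 50.0},
        "T_i": {"min": 0.0, "max": 50.0},
    },
    "scalars": {
        "beta_N": {"min": 0.0, "max": 10.0},
    },
}

path = os.path.join(tempfile.mkdtemp(), "custom_bounds.json")
with open(path, "w") as f:
    json.dump(bounds, f, indent=2)

# Reading the file back confirms that it round-trips unchanged.
with open(path) as f:
    loaded = json.load(f)
```

The resulting path would then be passed as custom_bounds_filename=path when constructing the observation handler.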

set_action_variables(variables: dict[str, list[str]]) -> None

Set variables controlled by actions.

These variables are removed from the observation space to prevent redundancy between actions and observations.

Parameters:

variables (dict[str, list[str]]) – Action variables by category. Format: {"profiles": [names], "scalars": [names]}.

Return type:

None
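The filtering that set_action_variables enables can be pictured with plain dictionaries. remove_action_variables is a hypothetical helper written for this illustration, and the variable names are examples.

```python
def remove_action_variables(available, action_variables):
    """Drop every variable that is controlled by an action, category by category."""
    return {
        category: [
            name for name in names
            if name not in action_variables.get(category, [])
        ]
        for category, names in available.items()
    }


available = {"profiles": ["T_e", "T_i", "p_ecrh_e"], "scalars": ["Ip", "beta_N"]}
action_variables = {"profiles": ["p_ecrh_e"], "scalars": ["Ip"]}

print(remove_action_variables(available, action_variables))
# {'profiles': ['T_e', 'T_i'], 'scalars': ['beta_N']}
```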

set_state_variables(state: DataTree) -> None

Set available state variables from TORAX output.

Catalogs all variables from the TORAX DataTree for inclusion in the observation space.

Parameters:

state (DataTree) – TORAX DataTree with /profiles/ and /scalars/ datasets.

Return type:

None

extract_state_observation(datatree: DataTree) -> tuple[dict[str, dict[str, ndarray]], dict[str, dict[str, ndarray]]]

Extract state and observation data from TORAX output.

Returns both complete state (all variables) and filtered observation (selected variables only).

Parameters:

datatree (DataTree) – TORAX simulation output containing profiles and scalars datasets.

Returns:

  • state (dict): Complete state with all variables.

    Format {"profiles": {var: array}, "scalars": {var: value}}

  • observation (dict): Filtered observation with selected variables.

    Format {"profiles": {var: array}, "scalars": {var: value}}

Return type:

tuple[dict[str, dict[str, numpy.ndarray]], dict[str, dict[str, numpy.ndarray]]]

build_observation_space() -> Dict

Build Gymnasium observation space for selected variables.

Creates nested Dict space with Box spaces for each variable using configured bounds. Validates configuration and finalizes variable selection.

Returns:

Gymnasium Dict space with structure {"profiles": {var: Box}, "scalars": {var: Box}}

Return type:

gymnasium.spaces.Dict

Raises:

ValueError – If validation fails or required setup incomplete.

Creation of Configurations

The configuration returned under the "config" key by the method _get_torax_config is a dictionary in exactly the same format as the one used in TORAX. You can find more details about the configuration in the TORAX documentation. The configuration used in our example is available here.

Creation of Rewards

TORAX reward module.

This module provides functions to extract specific metrics from the state dictionary returned by the TORAX simulator. These metrics can be used to construct reward functions for reinforcement learning environments focused on tokamak control and optimization. Other reward functions can be created.
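As an illustration of how a further metric might be written, the sketch below computes a penalty from a profile stored in the state dictionary. Both functions are hypothetical and not part of gymtorax.rewards; the profile key "q" and the use of plain lists instead of NumPy arrays are assumptions.

```python
def fraction_below(profile, threshold):
    """Fraction of profile points that lie strictly below the threshold."""
    values = list(profile)
    if not values:
        raise ValueError("profile is empty")
    return sum(1 for v in values if v < threshold) / len(values)


def q_safety_penalty(state, q_key="q", threshold=1.0):
    """Negative reward proportional to the share of the q profile below threshold."""
    return -fraction_below(state["profiles"][q_key], threshold)


# Stand-in state with the {"profiles": ..., "scalars": ...} layout.
state = {"profiles": {"q": [0.9, 1.1, 1.8, 3.2]}, "scalars": {}}
print(q_safety_penalty(state))  # -0.25
```

Such a term could be added to the value returned by _compute_reward alongside the metrics provided by this module.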

gymtorax.rewards.get_fusion_gain(state: dict) -> float

Get the fusion gain \(Q\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The fusion gain \(Q\).

Return type:

float

gymtorax.rewards.get_beta_N(state: dict) -> float

Get the normalized \(\beta_N\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The normalized \(\beta_N\).

Return type:

float

gymtorax.rewards.get_tau_E(state: dict) -> float

Get the energy confinement time \(\tau_E\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The energy confinement time \(\tau_E\).

Return type:

float

gymtorax.rewards.get_h98(state: dict) -> float

Get the H-mode confinement quality factor from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The \(H98\) factor.

Return type:

float

gymtorax.rewards.get_q_profile(state: dict) -> ndarray

Get the safety factor profile \(q\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The safety factor profile \(q\).

Return type:

numpy.ndarray

gymtorax.rewards.get_q_min(state: dict) -> float

Get the minimum safety factor \(q_{min}\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The minimum safety factor \(q_{min}\).

Return type:

float

gymtorax.rewards.get_q95(state: dict) -> float

Get safety factor at 95% of the normalized poloidal flux coordinate.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The safety factor at 95% of the normalized poloidal flux coordinate.

Return type:

float

gymtorax.rewards.get_s_profile(state: dict) -> ndarray

Get the magnetic shear profile \(s\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The magnetic shear profile \(s\).

Return type:

numpy.ndarray