Environment Design

This section describes the core environment components that make up GymTorax: environments, actions, observations, and rewards. These are the building blocks for creating plasma control tasks.

Base Environment

TORAX Base Environment Module.

This module provides the abstract base class for TORAX plasma simulation environments compatible with the Gymnasium reinforcement learning framework. It integrates TORAX physics simulations with RL interfaces, handling time discretization, action/observation spaces, and the simulation lifecycle.

The BaseEnv class serves as a foundation for creating specific plasma control tasks by:

  • Managing TORAX configuration and simulation execution

  • Defining action and observation space structures

  • Handling time discretization and episode management

  • Providing hooks for custom reward functions and terminal conditions

  • Providing a configurable logging system for debugging and monitoring

Classes:

BaseEnv: Abstract base class for TORAX Gymnasium environments

Example

Create a custom environment by extending BaseEnv:

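>>> from gymtorax.envs.base_env import BaseEnv
>>> from gymtorax.action_handler import IpAction, EcrhAction
>>> from gymtorax.observation_handler import AllObservation
>>> class PlasmaControlEnv(BaseEnv):
...     def __init__(self, render_mode=None, **kwargs):
...         # Set environment-specific defaults
...         kwargs.setdefault("log_level", "info")
...         super().__init__(render_mode=render_mode, **kwargs)
...
...     def _define_observation_space(self):
...         # Observe everything except the impurity density profile
...         return AllObservation(exclude={"profiles": ["n_impurity"]})
...
...     def _define_action_space(self):
...         return [IpAction(), EcrhAction()]
...
...     def _get_torax_config(self):
...         # CONFIG: a TORAX configuration dictionary defined elsewhere
...         return CONFIG
...
...     def _compute_reward(self, state, next_state, action):
...         # Custom reward: keep the normalized beta close to 2.0
...         return -abs(next_state["scalars"]["beta_N"] - 2.0)
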
class gymtorax.envs.base_env.BaseEnv(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False)

Bases: Env, ABC

Abstract base class for TORAX plasma simulation environments.

This class integrates TORAX physics simulations with the Gymnasium reinforcement learning framework, providing a standardized interface for plasma control tasks. It handles the complexities of time discretization, simulation management, and action/observation space construction.

The environment operates by:

  1. Setting up logging configuration for debugging and monitoring

  2. Initializing TORAX configuration and simulation state

  3. Managing discrete time steps with configurable time intervals

  4. Applying actions by updating TORAX configuration parameters

  5. Executing simulation steps and extracting observations

  6. Computing rewards and determining episode termination
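The time discretization in step 3 can be pictured with a small sketch. It assumes, purely for illustration, that the agent acts at t = 0, delta_t_a, 2*delta_t_a, and so on, stopping before the total time T; the helper action_times is hypothetical and not part of GymTorax.

```python
from __future__ import annotations

import math


def action_times(T: float, delta_t_a: float) -> list[float]:
    """Hypothetical helper: times at which a new action would be applied.

    Assumes one action every delta_t_a seconds, starting at t = 0 and
    stopping before the total simulation time T.
    """
    if T <= 0 or delta_t_a <= 0:
        raise ValueError("T and delta_t_a must be positive")
    # Subtract a small tolerance so that floating-point round-off in
    # T / delta_t_a does not produce an extra step.
    n_steps = math.ceil(T / delta_t_a - 1e-9)
    return [k * delta_t_a for k in range(n_steps)]


# A 10 s episode with an action every 2.5 s gives 4 agent decisions.
print(action_times(10.0, 2.5))  # [0.0, 2.5, 5.0, 7.5]
```

Under this assumption the episode length in agent steps is roughly T / delta_t_a, which is what the timestep counter would reach at truncation.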

Variables:
  • observation_handler (Observation) – Handles observation space and data extraction

  • action_handler (ActionHandler) – Manages action space and parameter updates

  • config (ConfigLoader) – TORAX configuration manager

  • torax_app (ToraxApp) – TORAX simulation wrapper

  • state (dict) – Current complete plasma state

  • observation (dict) – Current filtered observation

  • T (float) – Total simulation time [s]

  • delta_t_a (float) – Time interval between actions [s]

  • current_time (float) – Current simulation time [s]

  • timestep (int) – Current timestep counter

  • terminated (bool) – Episode termination flag

  • truncated (bool) – Episode truncation flag

Parameters:
  • render_mode (str | None)

  • log_level (str)

  • log_file (str | None)

  • plot_config (FigureProperties | str)

  • store_history (bool)

Note

Subclasses must implement these abstract methods:

  • _define_observation_space: Define observation space variables

  • _define_action_space: Define available control actions

  • _get_torax_config: Define TORAX configuration parameters

  • _compute_reward: Define reward signal (optional override)

__init__(render_mode: str | None = None, log_level: str = 'warning', log_file: str | None = None, plot_config: FigureProperties | str = 'default', store_history: bool = False) → None

Initialize the TORAX gymnasium environment.

This method sets up the complete simulation environment including TORAX configuration, action/observation spaces, time discretization, and rendering components.

Parameters:
  • render_mode (str or None) – Rendering mode for visualization. Options: "human", "rgb_array", or None. Defaults to None.

  • log_level (str) – Logging level for environment operations. Options: "debug", "info", "warning", "error", "critical". Defaults to "warning".

  • log_file (str or None) – Path to log file for writing log messages. If None, logs to console. Defaults to None.

  • plot_config (str or FigureProperties) – Name of the plot configuration to use (e.g., "default"), or a TORAX FigureProperties instance for a custom plot configuration. Defaults to "default".

  • store_history (bool) – Whether to store the simulation history for later saving. Set to True if you plan to use the save_file method. Defaults to False.

Raises:
  • ValueError – If required parameters are missing for the chosen discretization method.

  • TypeError – If discretization_torax is not "auto" or "fixed".

  • KeyError – If required keys are missing from TORAX configuration.

Return type:

None

Note

Subclasses should use **kwargs to pass parameters to avoid explicit parameter listing and maintain flexibility as the base class evolves. Environment-specific defaults can be set using kwargs.setdefault() before calling super().__init__().

Subclasses must implement the abstract methods _define_observation_space, _define_action_space, _get_torax_config, and _compute_reward.

reset(*, seed: int | None = None, options: dict[str, Any] | None = None) → tuple[dict[str, Any], dict[str, Any]]

Reset the environment to its initial state for a new episode.

This method initializes a new simulation episode by:

  1. Resetting internal counters and flags

  2. Starting the TORAX simulation from initial conditions

  3. Extracting the initial observation state

  4. Optionally rendering the initial state

Parameters:
  • seed (int or None) – Random seed for reproducible episode initialization. Used to seed the environment’s random number generator for deterministic behavior across resets. If None, no seeding is performed. Defaults to None.

  • options (dict[str, Any] or None) – Additional options for environment reset. Currently unused but maintained for Gymnasium compatibility. Defaults to None.

Returns:

  • observation (dict): Initial observation of plasma state

  • info (dict): Additional information (empty dict)

Return type:

tuple[dict[str, Any], dict[str, Any]]

step(action: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) → tuple[dict[str, Any], float, bool, bool, dict[str, Any]]

Execute one environment step with the given action.

This method implements the core RL interaction by:

  1. Capturing the current state before action

  2. Applying the action to update TORAX configuration

  3. Running the simulation for one time interval

  4. Extracting the new observation state

  5. Computing the reward signal

  6. Checking for episode termination

  7. Updating time counters

Parameters:

action (dict[str, numpy.ndarray]) – Action dictionary containing parameter values for all configured actions.

Returns:

  • observation (dict): New plasma state observation

  • reward (float): Reward signal for this step

  • terminated (bool): True if episode ended due to terminal condition

  • truncated (bool): True if episode ended due to time/step limits

  • info (dict): Additional step information

Return type:

tuple[dict[str, Any], float, bool, bool, dict[str, Any]]
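The reset/step contract described above is the standard Gymnasium interaction loop. The sketch below runs that loop against StubEnv, a hypothetical stand-in with toy (non-physical) dynamics that never calls TORAX and uses plain lists instead of numpy arrays. With a real environment, the same run_episode function would receive a BaseEnv subclass instance and a policy that returns an action dictionary.

```python
from __future__ import annotations

from typing import Any, Callable


class StubEnv:
    """Hypothetical stand-in mimicking BaseEnv's reset/step signatures (no TORAX)."""

    def __init__(self, T: float = 5.0, delta_t_a: float = 1.0) -> None:
        self.T = T                  # total simulation time [s]
        self.delta_t_a = delta_t_a  # time between actions [s]
        self.current_time = 0.0
        self.timestep = 0

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        self.current_time, self.timestep = 0.0, 0
        return {"scalars": {"beta_N": 1.0}}, {}

    def step(self, action: dict[str, list[float]]):
        self.current_time += self.delta_t_a
        self.timestep += 1
        # Toy dynamics only: beta_N simply follows the requested Ip in MA.
        beta_n = action["Ip"][0] / 1.0e6
        observation = {"scalars": {"beta_N": beta_n}}
        reward = -abs(beta_n - 2.0)
        terminated = False                       # no terminal condition in this stub
        truncated = self.current_time >= self.T  # time limit reached
        return observation, reward, terminated, truncated, {}


def run_episode(env, policy: Callable[[dict], dict], seed: int | None = None):
    """Run one episode and return (total reward, number of steps)."""
    observation, info = env.reset(seed=seed)
    total_reward, terminated, truncated = 0.0, False, False
    while not (terminated or truncated):
        action = policy(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
    return total_reward, env.timestep


total, steps = run_episode(StubEnv(), lambda obs: {"Ip": [2.0e6]})
print(total, steps)  # 0.0 5
```

The loop stops on either flag, which is why step returns terminated (a terminal condition) and truncated (a time or step limit) separately.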

close() → None

Clean up environment resources.

Return type:

None

render() → ndarray | None

Render the current environment state following Gymnasium convention.

Returns:

  • numpy.ndarray: RGB array of shape (height, width, 3) if render_mode is "rgb_array".

  • None: if render_mode is "human" or no renderer is available.

Return type:

numpy.ndarray | None

save_file(file_name)

Save the simulation output data to a file.

This method saves the complete simulation history to the specified file. The environment must have been created with store_history=True for this method to work.

Parameters:

file_name (str) – The path and filename where the output should be saved. The file format is typically NetCDF (.nc extension).

Raises:

Action Handling

The action handling system defines what parameters the RL agent can control and how these map to TORAX configuration updates.

Action Handler

class gymtorax.action_handler.ActionHandler(actions: list[Action])

Bases: object

Internal container and manager for multiple actions.

This class is used internally by the gymtorax framework to manage collections of actions.

Parameters:

actions (list[Action]) – List of Action instances to manage.

Variables:
  • actions – Internal dictionary of managed actions indexed by name.

  • action_space – Gymnasium Dict space representing all managed actions.

  • number_of_updates – Counter tracking the number of action updates performed.

__init__(actions: list[Action]) → None

Initialize the ActionHandler with a list of actions.

Validates action compatibility, builds the action space, and sets up internal tracking structures.

Parameters:

actions (list[Action]) – List of Action instances to manage. Actions must have unique names and compatible configuration mappings.

Raises:

ValueError – If duplicate action names or configuration paths are found, or if incompatible actions (e.g., both Ip and Vloop) are provided.

Return type:

None
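The checks listed under Raises can be sketched as follows. This is a simplified reconstruction, not the GymTorax source: FakeAction is a hypothetical stand-in carrying only a name and a config_mapping, and the Ip/V_loop exclusion is assumed to be checked through the name attributes of IpAction ("Ip") and VloopAction ("V_loop").

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeAction:
    """Hypothetical stand-in for an Action: only the fields the checks need."""
    name: str
    config_mapping: dict[tuple[str, ...], tuple[int, float]] = field(default_factory=dict)


def validate_actions(actions: list[FakeAction]) -> None:
    """Sketch of the ActionHandler.__init__ compatibility checks."""
    names = [action.name for action in actions]
    duplicates = sorted({n for n in names if names.count(n) > 1})
    if duplicates:
        raise ValueError(f"Duplicate action names: {duplicates}")

    # Each configuration path may be driven by at most one action.
    seen_paths: set[tuple[str, ...]] = set()
    for action in actions:
        for path in action.config_mapping:
            if path in seen_paths:
                raise ValueError(f"Config path {path} is mapped by more than one action")
            seen_paths.add(path)

    # Plasma current and loop voltage cannot be controlled at the same time.
    if "Ip" in names and "V_loop" in names:
        raise ValueError("Ip and V_loop actions are incompatible")
```

For example, an IpAction together with an EcrhAction passes, while IpAction plus VloopAction is rejected.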

get_actions() → dict[str, Action]

Get the dictionary of managed actions.

Returns:

Dictionary mapping action names to the Action instances managed by this handler.

Return type:

dict[str, Action]

get_action_variables() → dict[str, list[str]]

Get a dictionary of state variables modified by the managed actions.

Returns:

Dictionary mapping variable categories to lists of modified state variable names.

Return type:

dict[str, list[str]]
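One way to combine the state_var dictionaries of the managed actions into the result described above is sketched below. merge_state_vars is a hypothetical helper; the example inputs are the state_var attributes documented for IpAction and EcrhAction.

```python
from __future__ import annotations


def merge_state_vars(state_vars: list[dict[str, list[str]]]) -> dict[str, list[str]]:
    """Merge per-action state_var dicts into {category: [variable names]}."""
    merged: dict[str, list[str]] = {}
    for state_var in state_vars:
        for category, names in state_var.items():
            bucket = merged.setdefault(category, [])
            for name in names:
                if name not in bucket:  # keep first-seen order, skip repeats
                    bucket.append(name)
    return merged


# state_var of IpAction and EcrhAction
print(merge_state_vars([{"scalars": ["Ip"]}, {"scalars": ["P_ecrh_e"]}]))
# {'scalars': ['Ip', 'P_ecrh_e']}
```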

update_actions(actions: dict[str, ndarray[tuple[int, ...], dtype[floating]]]) → None

Update the current values of all managed actions.

This method validates that all provided actions exist in the handler, converts values to numpy arrays with correct dtypes, validates bounds, and updates each action’s internal values using the action’s _set_values method. The update counter is incremented after successful processing.

Parameters:

actions (dict) – Dictionary mapping action names to their new values. Keys must correspond to existing action names in this handler. Values must be numpy arrays compatible with each action's expected format and bounds.

Raises:

ValueError – If any action name in the actions dict does not exist in this handler’s managed actions.

Return type:

None
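The sequence described for update_actions (unknown names rejected, values converted, bounds checked, counter incremented only after success) can be sketched with plain Python lists. MiniActionHandler is hypothetical; in particular, rejecting out-of-bounds values with ValueError is an assumption, since the description above only says that bounds are validated. The example bounds are the documented defaults of IpAction (with 0.0 standing in for its minimum current) and EcrhAction.

```python
from __future__ import annotations


class MiniActionHandler:
    """Hypothetical stand-in: per-action bounds, current values, update counter."""

    def __init__(self, bounds: dict[str, tuple[list[float], list[float]]]) -> None:
        self.bounds = bounds
        self.values = {name: list(low) for name, (low, _) in bounds.items()}
        self.number_of_updates = 0

    def update_actions(self, actions: dict[str, list[float]]) -> None:
        unknown = sorted(set(actions) - set(self.bounds))
        if unknown:
            raise ValueError(f"Unknown action name(s): {unknown}")

        # Validate everything first so that a failure leaves the state untouched.
        validated = {}
        for name, raw in actions.items():
            values = [float(v) for v in raw]  # normalise the dtype
            low, high = self.bounds[name]
            if len(values) != len(low):
                raise ValueError(f"{name}: expected {len(low)} values, got {len(values)}")
            if any(not (lo <= v <= hi) for v, lo, hi in zip(values, low, high)):
                raise ValueError(f"{name}: {values} outside bounds {low}..{high}")
            validated[name] = values

        self.values.update(validated)
        self.number_of_updates += 1  # only reached when every action was accepted


inf = float("inf")
handler = MiniActionHandler({"Ip": ([0.0], [inf]), "ECRH": ([0.0, 0.0, 0.01], [inf, 1.0, inf])})
handler.update_actions({"Ip": [1.5e6], "ECRH": [5e6, 0.3, 0.1]})
print(handler.number_of_updates)  # 1
```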

build_action_space() → Dict

Build a Gymnasium Dict action space from all managed actions.

Creates a dictionary-based action space where each key corresponds to an action’s name and each value is a Box space with the action’s bounds and data type.

Returns:

A Dict space with action names as keys and Box spaces as values. Each Box space uses the action's min/max bounds and dtype.

Return type:

gymnasium.spaces.Dict

Base Action Class

class gymtorax.action_handler.Action(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)

Bases: ABC

Abstract base class for all TORAX simulation actions.

An action represents a controllable parameter or set of parameters that can influence plasma behavior. Each action has bounds, current values, and knows how to map itself to TORAX configuration dictionaries.

This class is designed to be extended by users to create custom actions for specific control parameters. Subclasses must define the class attributes to specify the action dimensionality, bounds, and configuration mapping.

Class Attributes:
  • name (str) – Unique identifier for this action type

  • dimension (int) – Number of parameters controlled by this action

  • default_min (list[float]) – Default minimum values for parameters

  • default_max (list[float]) – Default maximum values for parameters

  • default_ramp_rate (list[float | None]) – Default ramp rate limits for parameters. None means no limit for that parameter.

  • config_mapping (dict[tuple[str, …], tuple[int, float]]) – Mapping from configuration paths to parameter indices and scaling factors. Keys are tuples representing the nested path in the config dictionary, values are tuples of (parameter_index, scaling_factor).

  • state_var (dict[str, list[str]]) – Dictionary specifying the state variables directly modified by this action. Keys are state categories such as 'scalars' or 'profiles', values are lists of variable names (e.g., {'scalars': ['Ip']} or {'profiles': ['p_ecrh_e']}).

Variables:
  • values (list[float]) – Current parameter values

  • ramp_rate (numpy.ndarray) – Ramp rate limits for each parameter. numpy.inf indicates no ramp rate limit for that parameter.

  • dtype (numpy.dtype) – NumPy data type for action arrays (default: np.float64)

Example

Create a custom action for controlling two parameters:

>>> from gymtorax.action_handler import Action
>>> class TwoParamAction(Action):
...     name = "CustomTwoParam"
...     dimension = 2
...     default_min = [0.0, -5.0]
...     default_max = [10.0, 5.0]
...     default_ramp_rate = [0.5, None]  # First parameter limited, second unlimited
...     config_mapping = {
...         ('section', 'param1'): (0, 1),    # Parameter 0, no scaling
...         ('section', 'param2'): (1, 0.5),  # Parameter 1, scaled by 0.5
...     }
...     state_var = {'scalars': ['param1', 'param2']}
>>> action = TwoParamAction()

dimension: int
name: str
default_min: list[float]
default_max: list[float]
default_ramp_rate: list[float | None]
config_mapping: dict[tuple[str, ...], int]
state_var: dict[str, list[str]] = {}
__init__(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64) → None

Initialize an Action instance.

Parameters:
  • min (list[float] | None) – Custom minimum bounds for each parameter. If None, uses the class default_min values. Must have length equal to dimension.

  • max (list[float] | None) – Custom maximum bounds for each parameter. If None, uses the class default_max values. Must have length equal to dimension.

  • ramp_rate (list[float | None] | None) – Custom ramp rate limits for each parameter. If None, uses the class default_ramp_rate values. Must have length equal to dimension. Each element can be None (no limit) or a float (max change per step).

  • dtype (dtype) – NumPy data type for the action arrays (default: np.float64). Used for creating action spaces.

Raises:
  • ValueError – If name class attribute is not defined

  • ValueError – If dimension class attribute is not defined or not a positive integer

  • ValueError – If config_mapping class attribute is not defined

  • ValueError – If default_min, default_max, or default_ramp_rate do not match the dimension

  • ValueError – If provided min, max, or ramp_rate do not match the dimension

Return type:

None
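The default resolution and length checks listed above, together with the ramp-rate semantics (each element "None (no limit) or a float (max change per step)", with numpy.inf meaning unlimited internally), can be sketched in plain Python. The helper names are hypothetical and math.inf stands in for numpy.inf. Where GymTorax actually enforces the ramp limit is not specified here; clamping each parameter to within plus or minus ramp_rate of its previous value is an assumption.

```python
from __future__ import annotations

import math


def resolve(custom: list | None, default: list, dimension: int, label: str) -> list:
    """Use custom values if given, else the class default; check the length."""
    values = list(default if custom is None else custom)
    if len(values) != dimension:
        raise ValueError(f"{label} must have length {dimension}, got {len(values)}")
    return values


def ramp_limits(ramp_rate: list[float | None]) -> list[float]:
    """Translate None (no limit) into an infinite limit."""
    return [math.inf if r is None else float(r) for r in ramp_rate]


def apply_ramp_limit(previous: list[float], requested: list[float],
                     limits: list[float]) -> list[float]:
    """Clamp each requested value to previous +/- limit (assumed semantics)."""
    limited = []
    for prev, req, lim in zip(previous, requested, limits):
        if math.isinf(lim):
            limited.append(req)  # unlimited parameter: accept as is
        else:
            limited.append(prev + max(-lim, min(lim, req - prev)))
    return limited


# Defaults of the TwoParamAction example: ramp rate 0.5 on the first parameter only.
limits = ramp_limits(resolve(None, [0.5, None], 2, "ramp_rate"))
print(apply_ramp_limit([1.0, 0.0], [3.0, -4.0], limits))  # [1.5, -4.0]
```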

property min: ndarray[tuple[int, ...], dtype[floating]]

Minimum bounds for this action's parameters.

Returns:

Array of minimum values, one for each parameter controlled by this action.

Return type:

numpy.ndarray

property max: ndarray[tuple[int, ...], dtype[floating]]

Maximum bounds for this action's parameters.

Returns:

Array of maximum values, one for each parameter controlled by this action.

Return type:

numpy.ndarray

init_dict(config_dict: dict[str, Any]) → None

Initialize a TORAX configuration dictionary with this action's parameters.

This method sets up the configuration dictionary with the action's current values at time=0, creating the time-dependent parameter structure expected by TORAX.

Parameters:

config_dict (dict[str, Any]) – The TORAX configuration dictionary to initialize. Must have the nested structure expected by this action's config_mapping.

Raises:
  • KeyError – If the configuration dictionary does not have the expected structure for this action's parameters.

  • RuntimeError – If any error occurs during the initialization process.

Return type:

None

update_to_config(config_dict: dict[str, Any], time: float) → None

Update a TORAX configuration dictionary with new action values.

This method updates the time-dependent parameters in the configuration dictionary with the action's current values at the specified time. Scaling factors from config_mapping are applied consistently.

Parameters:
  • config_dict (dict[str, Any]) – The TORAX configuration dictionary to update. Must have been previously initialized with init_dict.

  • time (float) – Simulation time for this update. Must be > 0.

Return type:

None

Note

The configuration dictionary must have been initialized with init_dict before calling this method. Values are scaled by the factors defined in config_mapping before being stored.
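The two methods above can be sketched as follows. The sketch assumes the time-dependent structure is a {time: value} dictionary, one of the forms TORAX accepts for time-varying parameters; sketch_init_dict and sketch_update_to_config are hypothetical names, and the mapping is the one from the TwoParamAction example.

```python
from __future__ import annotations

from typing import Any

ConfigMapping = dict[tuple[str, ...], tuple[int, float]]


def _parent(config: dict[str, Any], path: tuple[str, ...]) -> dict[str, Any]:
    """Walk to the dict holding the last key of path (KeyError if missing)."""
    node = config
    for key in path[:-1]:
        node = node[key]
    return node


def sketch_init_dict(config: dict[str, Any], mapping: ConfigMapping,
                     values: list[float]) -> None:
    """Replace each mapped entry with a time series holding the t = 0 value."""
    for path, (index, scale) in mapping.items():
        _parent(config, path)[path[-1]] = {0.0: values[index] * scale}


def sketch_update_to_config(config: dict[str, Any], mapping: ConfigMapping,
                            values: list[float], time: float) -> None:
    """Add the scaled value at `time` to each previously initialized time series."""
    if time <= 0:
        raise ValueError("time must be > 0")
    for path, (index, scale) in mapping.items():
        _parent(config, path)[path[-1]][time] = values[index] * scale


mapping = {("section", "param1"): (0, 1), ("section", "param2"): (1, 0.5)}
config = {"section": {"param1": None, "param2": None}}
sketch_init_dict(config, mapping, [2.0, 4.0])
sketch_update_to_config(config, mapping, [3.0, -4.0], time=1.0)
print(config)
# {'section': {'param1': {0.0: 2.0, 1.0: 3.0}, 'param2': {0.0: 2.0, 1.0: -2.0}}}
```

Note how the 0.5 factor is applied on every write, so param2 always stores half of the agent's second parameter.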

get_mapping() → dict[tuple[str, ...], int]

Get the mapping of configuration dictionary paths to action parameter indices and scaling factors.

Returns:

Mapping of config dictionary paths to tuples of (action_parameter_index, scaling_factor).

Return type:

dict[tuple[str, …], tuple[int, float]]

__repr__() → str

Return a string representation of the action.

Returns:

String showing the action class name, current values, and bounds.

Return type:

str

Concrete Actions

class gymtorax.action_handler.IpAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)

Bases: Action

Example action for controlling plasma current (Ip).

This action controls the plasma current parameter in TORAX simulations. It is a single-parameter action with non-negative bounds.

Class Attributes:
  • name – "Ip"

  • dimension – 1 (single parameter)

  • default_min – [_MIN_IP_AMPS] (minimum current per TORAX requirements)

  • default_max – [numpy.inf]

  • default_ramp_rate – [None]

  • config_mapping – Maps to ('profile_conditions', 'Ip')

  • state_var – {'scalars': ['Ip']} (directly modifies the plasma current scalar)

Action Parameters:

0 – Plasma current (Ip) in Amperes

Example

>>> from gymtorax.action_handler import IpAction
>>> ip_action = IpAction()
>>> ip_action._set_values([1.5e6])  # 1.5 MA plasma current

name: str = 'Ip'
dimension: int = 1
default_min: list[float] = [torax._src.config.profile_conditions._MIN_IP_AMPS]
default_max: list[float] = [inf]
default_ramp_rate: list[float | None] = [None]
config_mapping: dict[tuple[str, ...], int] = {('profile_conditions', 'Ip'): (0, 1)}
state_var: dict[str, list[str]] = {'scalars': ['Ip']}
class gymtorax.action_handler.VloopAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)

Bases: Action

Example action for controlling loop voltage at the last closed flux surface.

This action controls the loop voltage parameter (v_loop_lcfs) in TORAX simulations. It is a single-parameter action with non-negative bounds.

Class Attributes:
  • name – "V_loop"

  • dimension – 1 (single parameter)

  • default_min – [0.0]

  • default_max – [numpy.inf]

  • default_ramp_rate – [None]

  • config_mapping – Maps to ('profile_conditions', 'v_loop_lcfs')

  • state_var – {'scalars': ['v_loop_lcfs']} (directly modifies the loop voltage scalar)

Action Parameters:

0 – Loop voltage (v_loop_lcfs) in Volts

Example

>>> from gymtorax.action_handler import VloopAction
>>> vloop_action = VloopAction()
>>> vloop_action._set_values([2.5])  # 2.5 V loop voltage

name: str = 'V_loop'
dimension: int = 1
default_min: list[float] = [0.0]
default_max: list[float] = [inf]
default_ramp_rate: list[float | None] = [None]
config_mapping: dict[tuple[str, ...], int] = {('profile_conditions', 'v_loop_lcfs'): (0, 1)}
state_var: dict[str, list[str]] = {'scalars': ['v_loop_lcfs']}
class gymtorax.action_handler.EcrhAction(min: list[float] | None = None, max: list[float] | None = None, ramp_rate: list[float | None] | None = None, dtype: dtype = np.float64)

Bases: Action

Example action for controlling Electron Cyclotron Resonance Heating (ECRH).

This action controls three ECRH parameters: total power, Gaussian location, and Gaussian width of the heating profile.

Class Attributes:
  • name – "ECRH"

  • dimension – 3 (power, location, width)

  • default_min – [0.0, 0.0, 0.01]

  • default_max – [numpy.inf, 1.0, numpy.inf]

  • default_ramp_rate – [None, None, None]

  • config_mapping – Maps to ECRH source parameters

  • state_var – {'scalars': ['P_ecrh_e']} (modifies the total electron-cyclotron power scalar)

Action Parameters:
  • 0 – Total power (P_total) in Watts

  • 1 – Gaussian location (gaussian_location) - normalized radius [0,1]

  • 2 – Gaussian width (gaussian_width) - profile width parameter

Example

>>> from gymtorax.action_handler import EcrhAction
>>> ecrh_action = EcrhAction()
>>> ecrh_action._set_values([5e6, 0.3, 0.1])  # 5 MW, r/a = 0.3, width = 0.1

name: str = 'ECRH'
dimension: int = 3
default_min: list[float] = [0.0, 0.0, 0.01]
default_max: list[float] = [inf, 1.0, inf]
default_ramp_rate: list[float | None] = [None, None, None]
config_mapping: dict[tuple[str, ...], int] = {('sources', 'ecrh', 'P_total'): (0, 1), ('sources', 'ecrh', 'gaussian_location'): (1, 1), ('sources', 'ecrh', 'gaussian_width'): (2, 1)}
state_var: dict[str, list[str]] = {'scalars': ['P_ecrh_e']}
class gymtorax.action_handler.NbiAction(nbi_w_to_ma=1 / 16e6, **kwargs)

Bases: Action

Example action for controlling Neutral Beam Injection (NBI).

This action controls three NBI parameters: heating power, Gaussian location, and Gaussian width of the heating profile. The current drive power is automatically calculated from the heating power using a configurable conversion factor.

Class Attributes:
  • name – "NBI"

  • dimension – 3 (heating power, location, width)

  • default_min – [0.0, 0.0, 0.01]

  • default_max – [numpy.inf, 1.0, numpy.inf]

  • default_ramp_rate – [None, None, None]

  • config_mapping – Maps to generic heat and current source parameters in the TORAX configuration

  • state_var – {'scalars': ['P_aux_generic_total']} (modifies the total auxiliary power scalar)

Variables:
  • nbi_w_to_ma – Conversion factor from heating power (W) to current drive (MA). Default is 1/16e6, meaning 16MW of heating produces 1MA of current.

  • config_mapping (dict[tuple[str, ...], int]) – Dynamically created in __init__ to use the specified conversion factor.

Action Parameters:
  • 0 – Heating power (generic_heat P_total) in Watts

  • 1 – Gaussian location (shared by heat and current) - normalized radius [0,1]

  • 2 – Gaussian width (shared by heat and current) - profile width parameter

Example

>>> from gymtorax.action_handler import NbiAction
>>> nbi_action = NbiAction()
>>> nbi_action._set_values([10e6, 0.4, 0.2])  # 10 MW heating, r/a = 0.4, width = 0.2
>>> # NBI with custom conversion factor
>>> nbi_custom = NbiAction(nbi_w_to_ma=1/20e6)  # 20 MW per 1 MA
>>> nbi_custom._set_values([20e6, 0.3, 0.15])
>>> # NBI with current drive disabled
>>> nbi_heating_only = NbiAction(nbi_w_to_ma=0)
>>> nbi_heating_only._set_values([15e6, 0.5, 0.1])

name: str = 'NBI'
dimension: int = 3
default_min: list[float] = [0.0, 0.0, 0.01]
default_max: list[float] = [inf, 1.0, inf]
default_ramp_rate: list[float | None] = [None, None, None]
state_var: dict[str, list[str]] = {'scalars': ['P_aux_generic_total']}
__init__(nbi_w_to_ma=1 / 16e6, **kwargs)

Initialize NbiAction with configurable heating-to-current conversion.

Parameters:
  • nbi_w_to_ma – Conversion factor from heating power (Watts) to current drive (MA). Default is 1/16e6, meaning 16MW of heating produces 1MA of current drive. Set to 0 to disable current drive while keeping heating.

  • **kwargs – Additional arguments passed to the parent Action class (min, max, ramp_rate, dtype).

Example

>>> # Default conversion (16MW -> 1MA)
>>> nbi = NbiAction()
>>> # Custom conversion (20MW -> 1MA)
>>> nbi = NbiAction(nbi_w_to_ma=1/20e6)
>>> # Heating only, no current drive
>>> nbi = NbiAction(nbi_w_to_ma=0)
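The conversion factor turns the single heating-power parameter into a driven current. The arithmetic is shown below with a hypothetical helper, nbi_outputs; how the two values are written to the TORAX generic heat and current sources (the actual config paths) is handled by NbiAction's config_mapping and is not reproduced here.

```python
from __future__ import annotations


def nbi_outputs(p_heat_w: float, nbi_w_to_ma: float = 1 / 16e6) -> tuple[float, float]:
    """Return (heating power [W], driven current [MA]) for a given NBI heating power."""
    return p_heat_w, p_heat_w * nbi_w_to_ma


for factor, label in [(1 / 16e6, "default"), (1 / 20e6, "20 MW per MA"), (0, "heating only")]:
    heat, current = nbi_outputs(10e6, factor)
    print(f"{label}: {heat / 1e6:.1f} MW heating -> {current:.3f} MA driven")
```

With 10 MW of heating, the default factor gives 0.625 MA, the 20 MW-per-MA factor gives 0.5 MA, and a factor of 0 leaves the heating unchanged while driving no current.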

Observation Handling

The observation system extracts relevant plasma state information and formats it for RL agents.

Base Observation Class

class gymtorax.observation_handler.Observation(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64)

Bases: ABC

Abstract base class for building observation spaces from TORAX DataTree outputs.

Converts TORAX simulation outputs into structured observation spaces for reinforcement learning environments. Handles variable selection, bounds specification, and automatic action variable filtering.

Variables:
  • variables_to_include (dict) – Variables to include in observation space.

  • variables_to_exclude (dict) – Variables to exclude from observation space.

  • custom_bounds (dict) – Custom bounds for variables.

  • dtype (numpy.dtype) – Data type for observation arrays.

  • action_variables (dict) – Variables controlled by actions.

  • state_variables (dict) – Available variables from TORAX output.

  • observation_variables (dict) – Final filtered observation variables.

  • bounds (dict) – Final bounds after processing.

__init__(variables: dict[str, list[str]] | None = None, custom_bounds_filename: str | None = None, exclude: dict[str, list[str]] | None = None, dtype: dtype = np.float64) → None

Initialize Observation handler.

Sets up configuration for the observation handler. Requires subsequent calls to set_state_variables(), set_action_variables(), and build_observation_space() before use.

Parameters:
  • variables (dict[str, list[str]] | None) – Variables to include. Format: {"profiles": [names], "scalars": [names]}. If None, includes all available variables except those in exclude.

  • custom_bounds_filename (str | None) – Path to JSON file with custom bounds. Format: {"profiles": {var: {"min": val, "max": val}}, "scalars": {...}}.

  • exclude (dict[str, list[str]] | None) – Variables to exclude. Format: {"profiles": [names], "scalars": [names]}. Cannot be combined with the variables parameter.

  • dtype (dtype) – Data type for observation arrays.

Raises:

ValueError – If both variables and exclude are specified, or if the configuration is otherwise invalid.

Return type:

None
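The selection rules described here and in set_action_variables (an include list or an exclude list, never both, followed by removal of action-controlled variables) can be sketched with plain dictionaries. select_variables is a hypothetical helper; silently skipping requested names that are absent from the TORAX output is a simplification, and the variable names in the example are illustrative.

```python
from __future__ import annotations


def select_variables(
    state_variables: dict[str, list[str]],
    variables: dict[str, list[str]] | None = None,
    exclude: dict[str, list[str]] | None = None,
    action_variables: dict[str, list[str]] | None = None,
) -> dict[str, list[str]]:
    """Return {category: [names]} for the observation space."""
    if variables is not None and exclude is not None:
        raise ValueError("Specify either 'variables' or 'exclude', not both")
    action_variables = action_variables or {}
    selected: dict[str, list[str]] = {}
    for category, available in state_variables.items():
        if variables is not None:
            names = [n for n in variables.get(category, []) if n in available]
        else:
            excluded = set((exclude or {}).get(category, []))
            names = [n for n in available if n not in excluded]
        # Variables set directly by actions are dropped to avoid redundancy.
        controlled = set(action_variables.get(category, []))
        selected[category] = [n for n in names if n not in controlled]
    return selected


available = {"profiles": ["T_e", "T_i", "n_impurity"], "scalars": ["Ip", "beta_N"]}
print(select_variables(available, exclude={"profiles": ["n_impurity"]},
                       action_variables={"scalars": ["Ip"]}))
# {'profiles': ['T_e', 'T_i'], 'scalars': ['beta_N']}
```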

set_action_variables(variables: dict[str, list[str]]) → None

Set variables controlled by actions.

These variables are removed from the observation space to prevent redundancy between actions and observations.

Parameters:

variables (dict[str, list[str]]) – Action variables by category. Format: {"profiles": [names], "scalars": [names]}.

Return type:

None

set_state_variables(state: DataTree) → None

Set available state variables from TORAX output.

Catalogs all variables from the TORAX DataTree for inclusion in the observation space.

Parameters:

state (DataTree) – TORAX DataTree with /profiles/ and /scalars/ datasets.

Return type:

None

extract_state_observation(datatree: DataTree) → tuple[dict[str, dict[str, ndarray]], dict[str, dict[str, ndarray]]]

Extract state and observation data from TORAX output.

Returns both complete state (all variables) and filtered observation (selected variables only).

Parameters:

datatree (DataTree) – TORAX simulation output containing profiles and scalars datasets.

Returns:

  • state (dict): Complete state with all variables.

    Format {"profiles": {var: array}, "scalars": {var: value}}

  • observation (dict): Filtered observation with selected variables.

    Format {"profiles": {var: array}, "scalars": {var: value}}

Return type:

tuple[dict[str, dict[str, numpy.ndarray]], dict[str, dict[str, numpy.ndarray]]]
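Once the observation variables are fixed, deriving the filtered observation from the complete state is a dictionary projection. The sketch assumes the nested {"profiles": {...}, "scalars": {...}} format shown above and uses plain lists in place of numpy arrays; filter_observation is a hypothetical name.

```python
from __future__ import annotations

from typing import Any


def filter_observation(state: dict[str, dict[str, Any]],
                       observation_variables: dict[str, list[str]]) -> dict[str, dict[str, Any]]:
    """Keep only the selected variables of each category of the full state."""
    return {
        category: {name: state[category][name] for name in names}
        for category, names in observation_variables.items()
    }


state = {
    "profiles": {"T_e": [5.0, 3.0, 1.0], "n_impurity": [0.1, 0.1, 0.1]},
    "scalars": {"Ip": 1.5e6, "beta_N": 1.8},
}
observation = filter_observation(state, {"profiles": ["T_e"], "scalars": ["beta_N"]})
print(observation)  # {'profiles': {'T_e': [5.0, 3.0, 1.0]}, 'scalars': {'beta_N': 1.8}}
```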

build_observation_space() → Dict

Build Gymnasium observation space for selected variables.

Creates nested Dict space with Box spaces for each variable using configured bounds. Validates configuration and finalizes variable selection.

Returns:

Gymnasium Dict space with structure {"profiles": {var: Box}, "scalars": {var: Box}}.

Return type:

gymnasium.spaces.Dict

Raises:

ValueError – If validation fails or required setup incomplete.

Concrete Observations

class gymtorax.observation_handler.AllObservation(exclude=None, custom_bounds_file=None)

Bases: Observation

Observation handler that includes all available TORAX variables.

Creates a complete observation space containing all profile and scalar variables available in the TORAX simulation output. Supports variable exclusions and custom bounds configuration.

Example

>>> from gymtorax.observation_handler import AllObservation
>>> obs = AllObservation()
>>> obs = AllObservation(exclude={"profiles": ["n_impurity"]})
>>> obs = AllObservation(custom_bounds_file="bounds.json")

__init__(exclude=None, custom_bounds_file=None) → None

Initialize AllObservation with all available TORAX variables.

Creates an observation handler that includes all available TORAX variables by default; variables can be excluded and custom bounds supplied through the exclude and custom_bounds_file arguments.

Parameters:
  • exclude (dict[str, list[str]] or None) – Variables to exclude. Format: {“profiles”: [names], “scalars”: [names]}.

  • custom_bounds_file (str or None) – Path to JSON file containing custom bounds for variables.

Return type:

None

Example

>>> obs = AllObservation()
>>> obs = AllObservation(exclude={"profiles": ["psi"]})
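A custom bounds file uses the JSON format given for custom_bounds_filename in the Observation constructor. The sketch below writes such a file with the standard library; the variable names and numbers are placeholders rather than recommended bounds, and the resulting path would then be passed as AllObservation(custom_bounds_file=path).

```python
import json
import os
import tempfile

# Format: {"profiles": {var: {"min": ..., "max": ...}}, "scalars": {...}}
bounds = {
    "profiles": {"T_e": {"min": 0.0, "max": 50.0}},
    "scalars": {"beta_N": {"min": 0.0, "max": 5.0}},
}

path = os.path.join(tempfile.gettempdir(), "bounds.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(bounds, f, indent=2)

# Reading the file back yields the same nested dictionary.
with open(path, encoding="utf-8") as f:
    loaded = json.load(f)
print(loaded == bounds)  # True
```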

Reward Functions

Reward functions define the control objectives and translate plasma performance into scalar signals for RL training.

TORAX reward module.

This module provides functions that extract specific metrics from the state dictionary returned by the TORAX simulator. These metrics can be used to build reward functions for reinforcement learning environments focused on tokamak control and optimization; users can also write their own reward functions.
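As an illustration of building a reward from such metrics, the sketch below tracks a target beta_N and subtracts a fixed penalty when q_min falls below a floor. It reads the state dictionary directly, assuming the scalar keys "beta_N" (as used in the BaseEnv example) and "q_min" (an assumed key name); inside an environment one would more likely call get_beta_N and get_q_min. The target, floor, and penalty values are arbitrary.

```python
def beta_tracking_reward(state: dict, target: float = 2.0,
                         q_min_floor: float = 1.0, penalty: float = 10.0) -> float:
    """Illustrative reward: -|beta_N - target|, minus a penalty if q_min < floor."""
    beta_n = state["scalars"]["beta_N"]
    q_min = state["scalars"]["q_min"]  # assumed key name
    reward = -abs(beta_n - target)
    if q_min < q_min_floor:
        reward -= penalty
    return reward


print(beta_tracking_reward({"scalars": {"beta_N": 1.5, "q_min": 1.2}}))  # -0.5
print(beta_tracking_reward({"scalars": {"beta_N": 1.5, "q_min": 0.9}}))  # -10.5
```

In a BaseEnv subclass, _compute_reward(self, state, next_state, action) could simply return this function applied to next_state.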

gymtorax.rewards.get_fusion_gain(state: dict) → float

Get the fusion gain \(Q\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The fusion gain \(Q\).

Return type:

float

gymtorax.rewards.get_beta_N(state: dict) → float

Get the normalized beta \(\beta_N\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The normalized beta \(\beta_N\).

Return type:

float

gymtorax.rewards.get_tau_E(state: dict) → float

Get the energy confinement time \(\tau_E\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The energy confinement time \(\tau_E\).

Return type:

float

gymtorax.rewards.get_h98(state: dict) → float

Get the H-mode confinement quality factor from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The \(H98\) factor.

Return type:

float
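For reference, the standard definitions behind get_tau_E and get_h98 are given below, with W_th the thermal stored energy and P_loss the loss power. The H98 factor normalizes tau_E by the IPB98(y,2) empirical scaling (I_p in MA, B_T in T, line-averaged density in 10^19 m^-3, P_loss in MW, R in m, kappa_a the areal elongation, epsilon = a/R, M the average ion mass in amu). Exactly which power terms TORAX includes in P_loss is not specified here.

```latex
\tau_E = \frac{W_{th}}{P_{loss}}, \qquad P_{loss} = P_{heat} - \frac{dW_{th}}{dt}

H_{98} = \frac{\tau_E}{\tau_{98}}, \qquad
\tau_{98} = 0.0562\, I_p^{0.93} B_T^{0.15} \bar{n}_{19}^{0.41} P_{loss}^{-0.69}
            R^{1.97} \kappa_a^{0.78} \varepsilon^{0.58} M^{0.19}
```

An \(H_{98}\) above 1 means the plasma confines energy better than the scaling predicts.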

gymtorax.rewards.get_q_profile(state: dict) → ndarray

Get the safety factor profile \(q\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The safety factor profile \(q\).

Return type:

numpy.ndarray

gymtorax.rewards.get_q_min(state: dict) → float

Get the minimum safety factor \(q_{min}\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing scalar values.

Returns:

The minimum safety factor \(q_{min}\).

Return type:

float

gymtorax.rewards.get_q95(state: dict) → float

Get the safety factor \(q_{95}\) at 95% of the normalized poloidal flux coordinate.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The safety factor \(q_{95}\) at 95% of the normalized poloidal flux coordinate.

Return type:

float
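If q95 has to be obtained from a q profile rather than read as a precomputed scalar, linear interpolation at a normalized poloidal flux of 0.95 is the simplest approach. The sketch assumes an ascending psi_norm grid with q given on the same grid; whether get_q95 reads a stored value or interpolates is not stated above, and the helper names are hypothetical.

```python
from __future__ import annotations

from bisect import bisect_left


def interp_linear(x_grid: list[float], y: list[float], x: float) -> float:
    """Linearly interpolate y(x) on an ascending grid (no extrapolation)."""
    if not x_grid[0] <= x <= x_grid[-1]:
        raise ValueError("x lies outside the grid")
    i = bisect_left(x_grid, x)  # first index with x_grid[i] >= x
    if x_grid[i] == x:
        return y[i]
    x0, x1 = x_grid[i - 1], x_grid[i]
    weight = (x - x0) / (x1 - x0)
    return y[i - 1] + weight * (y[i] - y[i - 1])


def q95_from_profile(psi_norm: list[float], q: list[float]) -> float:
    """Safety factor at 95% of the normalized poloidal flux."""
    return interp_linear(psi_norm, q, 0.95)


psi_norm = [0.0, 0.5, 0.9, 1.0]
q = [1.0, 1.5, 3.0, 4.0]
print(q95_from_profile(psi_norm, q))  # approximately 3.5
```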

gymtorax.rewards.get_s_profile(state: dict) → ndarray

Get the magnetic shear profile \(s\) from the state dictionary.

Parameters:

state (dict) – The state dictionary containing profile values.

Returns:

The magnetic shear profile \(s\).

Return type:

numpy.ndarray
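The magnetic shear is conventionally defined as \(s = (r/q)\, dq/dr\). The sketch below evaluates that definition on a discrete profile with finite differences (central in the interior, one-sided at the ends). The radial coordinate r and the helper name are assumptions; get_s_profile may use a different coordinate or discretization.

```python
from __future__ import annotations


def magnetic_shear(r: list[float], q: list[float]) -> list[float]:
    """Finite-difference estimate of s = (r / q) * dq/dr on a radial grid."""
    n = len(r)
    if n < 3 or len(q) != n:
        raise ValueError("need matching r and q arrays with at least 3 points")
    shear = []
    for i in range(n):
        if i == 0:
            dq_dr = (q[1] - q[0]) / (r[1] - r[0])  # forward difference
        elif i == n - 1:
            dq_dr = (q[-1] - q[-2]) / (r[-1] - r[-2])  # backward difference
        else:
            dq_dr = (q[i + 1] - q[i - 1]) / (r[i + 1] - r[i - 1])  # central difference
        shear.append(r[i] / q[i] * dq_dr)
    return shear


r = [0.0, 0.25, 0.5, 0.75, 1.0]
q = [1.0 + x * x for x in r]  # toy profile: dq/dr = 2r, so s = 2r^2 / (1 + r^2)
print(magnetic_shear(r, q)[2])  # 0.4
```

For this quadratic toy profile the central difference is exact, so interior points match \(2r^2/(1+r^2)\); the one-sided end points are only first-order accurate.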