TORAX Wrapper

The TORAX wrapper is a critical component that bridges the Gymnasium environment interface and the TORAX plasma physics simulation engine. It handles state management and configuration, and provides a clean interface for reinforcement learning interactions.

torax_app

The main application class that orchestrates TORAX simulations.

High-level application interface for running TORAX plasma simulations.

This module provides the ToraxApp class, which wraps the TORAX simulator into a Pythonic interface suitable for reinforcement learning and episodic simulation workflows. It manages the simulation lifecycle, configuration updates, state tracking, and output handling.

This abstraction allows Gymnasium-style environments and control algorithms to interact with TORAX without dealing with its low-level orchestration details.

class gymtorax.torax_wrapper.torax_app.ToraxApp(config_loader: ConfigLoader, delta_t_a: float, store_history: bool = False)

Bases: object

TORAX simulation application wrapper.

This class provides a high-level interface for running TORAX plasma simulations in an episodic manner, suitable for reinforcement learning environments. It manages the simulation lifecycle, state tracking, and configuration updates.

The application follows a start/reset -> run -> update cycle:
  1. Initialize with configuration and action timestep

  2. Call reset() to prepare for a new episode

  3. Call run() repeatedly to advance the simulation

  4. Call update_config() between runs to update action parameters

Variables:
  • config (ConfigLoader) – Current configuration loader instance

  • initial_config (ConfigLoader) – Original configuration for resetting

  • delta_t_a (float) – Action timestep: simulated time advanced by each run() call

  • store_history (bool) – Whether to store complete simulation history

  • current_sim_state (ToraxSimState) – Current simulation state

  • current_sim_output (PostProcessedOutputs) – Current post-processed outputs

  • state (StateHistory) – Current state history (single timestep)

  • history_list (list) – Complete history list (if store_history=True)

  • is_started (bool) – Whether the application has been initialized

  • t_current (float) – Current simulation time

  • t_final (float) – Final simulation time for current episode

  • last_run_time (float) – Timestamp of last run() call (for performance monitoring)

__init__(config_loader: ConfigLoader, delta_t_a: float, store_history: bool = False)

Initialize ToraxApp with configuration and simulation parameters.

Parameters:
  • config_loader (ConfigLoader) – ConfigLoader instance containing TORAX configuration

  • delta_t_a (float) – Action timestep in seconds. Each call to run() advances the simulation by this amount

  • store_history (bool) – If True, stores complete simulation history for later analysis. If False, only keeps current state (more memory efficient)

Note

reset() must be called before first use. The constructor only sets up instance variables and enables performance monitoring if debug logging is enabled.
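The memory trade-off behind store_history can be illustrated with a small recorder. This is a hypothetical sketch, not ToraxApp's implementation: `HistoryRecorder`, `record`, and `full_history` are invented names. The only behaviors taken from this page are that the current state is always kept, the complete history is kept only when store_history=True, and history-dependent methods raise RuntimeError otherwise.

```python
from typing import Any


class HistoryRecorder:
    """Hypothetical sketch of the store_history trade-off (not GymTORAX code)."""

    def __init__(self, store_history: bool = False) -> None:
        self.store_history = store_history
        self.current: Any = None           # always holds the latest snapshot
        self.history_list: list[Any] = []  # grows only when store_history=True

    def record(self, snapshot: Any) -> None:
        self.current = snapshot
        if self.store_history:
            self.history_list.append(snapshot)

    def full_history(self) -> list[Any]:
        if not self.store_history:
            raise RuntimeError("store_history=False: no history is available")
        return list(self.history_list)


lean = HistoryRecorder(store_history=False)
for t in (0.0, 0.1, 0.2):
    lean.record({"t": t})
print(lean.current, len(lean.history_list))  # {'t': 0.2} 0
```

With store_history=False, memory use stays constant over an episode, at the cost of not being able to reconstruct the trajectory afterwards.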

start()

Initialize TORAX simulation components.

This method sets up all the TORAX simulation infrastructure:
  • Transport and pedestal models

  • Geometry provider and source models

  • Static and dynamic runtime parameters

  • Solver and MHD models

  • Step function for simulation advancement

  • Initial simulation state and outputs

Called automatically by reset() if not already started.

reset()

Reset the simulation to initial conditions for a new episode.

This method prepares the application for a new simulation episode by:
  • Initializing TORAX components if not already started

  • Resetting simulation state to initial conditions

  • Creating fresh state history

  • Setting up time tracking (t_current=0, t_final from config)

  • Configuring first action step duration

run() -> tuple[bool, bool]

Execute one simulation step from t_current to t_current + delta_t_a.

This method advances the TORAX simulation by one action timestep, which may involve multiple internal TORAX timesteps. It handles:

  • Performance timing (if debug logging enabled)

  • TORAX run_loop execution with current configuration

  • State and output management

  • Error handling and recovery

  • Time progression tracking

Returns:

  • success (bool): True if simulation step completed successfully, False if an error occurred or simulation reached final time.

  • done (bool): True if whole simulation is done.

Return type:

tuple[bool, bool]

Raises:

RuntimeError – If reset() has not been called before running.

Note

  • Call update_config() between runs to modify simulation parameters

  • done is True when t_current >= t_final (episode complete)

  • Performance timing logged at DEBUG level shows the interval since the last run() call

  • Errors during simulation return success=False; the environment should then be reset
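The time bookkeeping described above (each run() covers t_current to t_current + delta_t_a until t_final is reached) can be sketched as a list of action windows. This is an illustrative helper, not part of GymTORAX: `action_windows` is a hypothetical name, and clipping the final window to t_final when the duration is not a multiple of delta_t_a is an assumption made here; ToraxApp may handle a remainder differently.

```python
import math


def action_windows(t_initial: float, t_final: float,
                   delta_t_a: float) -> list[tuple[float, float]]:
    """Split [t_initial, t_final] into consecutive run() windows of length delta_t_a.

    Assumption: the last window is clipped to t_final.
    """
    if delta_t_a <= 0:
        raise ValueError("delta_t_a must be positive")
    # Small tolerance so a quotient landing a hair above an integer because of
    # rounding does not produce an extra, empty window.
    n = math.ceil((t_final - t_initial) / delta_t_a - 1e-9)
    windows = []
    for k in range(n):
        # Multiply instead of accumulating to avoid floating-point drift.
        start = t_initial + k * delta_t_a
        end = min(t_initial + (k + 1) * delta_t_a, t_final)
        windows.append((start, end))
    return windows


print(action_windows(0.0, 1.0, 0.25))
# [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
```

Computing each boundary from the step index, rather than repeatedly adding delta_t_a, keeps the comparison against t_final stable over long episodes.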

update_config(action) -> None

Update simulation configuration with new action parameters.

This method applies new control parameters to the TORAX configuration for the next simulation step.

Parameters:

action – Action dictionary containing new parameter values. Must match the format expected by the ConfigLoader.

Raises:

ValueError – If action format is invalid or configuration update fails.

Return type:

None

get_output_datatree(start: int = 0, end: int = -1) -> DataTree

Return the full simulation history as an xarray DataTree.

This method reconstructs the complete trajectory of the simulation, including all state and post-processed output snapshots, as an xarray DataTree suitable for analysis and visualization. If start and end are specified, only data between those time values (inclusive) is selected for every dataset in the DataTree that has a ‘time’ coordinate. Requires that the ToraxApp was initialized with store_history=True so that the full history is available.

Parameters:
  • start (int or float) – Start time for selection. Defaults to 0.

  • end (int or float) – End time for selection. Defaults to -1 (no upper limit).

Returns:

The complete simulation history as an xarray DataTree, with all timesteps and outputs, or only the selected time range if specified.

Return type:

xarray.DataTree

Raises:

RuntimeError – If store_history was not enabled and thus no history is available.
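The selection semantics of start and end (inclusive on both ends, with end=-1 meaning no upper limit) can be sketched on a plain list of time values. This is a hypothetical helper, `select_time_indices`, written only to illustrate the documented parameters. On an xarray object with a ‘time’ coordinate, label-based slicing such as `.sel(time=slice(start, end))` is likewise inclusive of both endpoints, though whether get_output_datatree uses it internally is not stated here.

```python
import math
from typing import Sequence


def select_time_indices(times: Sequence[float], start: float = 0,
                        end: float = -1) -> list[int]:
    """Indices of samples with start <= t <= end; end == -1 means no upper limit."""
    upper = math.inf if end == -1 else end
    return [i for i, t in enumerate(times) if start <= t <= upper]


times = [0.0, 0.5, 1.0, 1.5, 2.0]
print(select_time_indices(times, 0.5, 1.5))  # [1, 2, 3]
print(select_time_indices(times, 1.0))       # [2, 3, 4]
```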

save_output_file(file_name)

Save complete simulation history to NetCDF file.

This method saves the full simulation trajectory to a NetCDF file suitable for analysis and visualization. Requires store_history=True in the constructor.

Parameters:

file_name (str) – Output file path with .nc extension

get_state_data()

Get current simulation state as xarray DataTree.

This method returns the current simulation state in xarray format, suitable for observation extraction and analysis.

Returns:

Current simulation state.

Return type:

xarray.DataTree

Raises:

RuntimeError – If simulation state has not been computed yet.

Note

  • Returns single-timestep state (current moment)

  • For the full history, use get_output_datatree() or save_output_file(); both require store_history=True

config_loader

Configuration system for managing TORAX physics parameters and simulation settings.

Configuration loader for TORAX simulation package.

This module provides a wrapper around TORAX configuration dictionaries, offering convenient access to common simulation parameters and configuration management for Gymnasium environments.

class gymtorax.torax_wrapper.config_loader.ConfigLoader(config: dict[str, Any], action_handler: ActionHandler)

Bases: object

A wrapper class for TORAX configuration management.

This class handles the conversion between Python dictionaries and TORAX’s internal configuration format, providing convenient access to simulation parameters commonly needed in Gymnasium environments.

__init__(config: dict[str, Any], action_handler: ActionHandler)

Initialize the configuration loader.

Parameters:
  • config (dict[str, Any]) – Dictionary containing TORAX configuration parameters.

  • action_handler (ActionHandler) – ActionHandler instance for managing actions.

Raises:
  • ValueError – If the configuration dictionary is invalid

  • TypeError – If config is not a dictionary

get_dict() -> dict[str, Any]

Get the raw configuration dictionary.

Returns:

The original configuration dictionary

Return type:

dict[str, Any]

get_total_simulation_time() -> float

Get the total simulation time in seconds.

This extracts the t_final parameter from the numerics section, which defines how long the plasma simulation should run.

Returns:

Total simulation time in seconds

Raises:
  • KeyError – If the configuration does not contain the required keys

  • TypeError – If the value is not a number

Return type:

float

set_total_simulation_time(time: float) -> None

Set the total simulation time in seconds.

This updates the t_final parameter in the numerics section, which defines how long the plasma simulation should run.

Parameters:

time (float) – Total simulation time in seconds

Raises:
  • KeyError – If the configuration does not contain the required keys

  • TypeError – If the value is not a number

Return type:

None
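The getter and setter above operate on the numerics section of the raw configuration dictionary. A minimal sketch on a plain dict, assuming the nesting config["numerics"]["t_final"] implied by the descriptions; `get_t_final` and `set_t_final` are hypothetical helpers that mirror the documented KeyError/TypeError behavior, not ConfigLoader methods.

```python
from numbers import Real
from typing import Any


def get_t_final(config: dict[str, Any]) -> float:
    """Read numerics.t_final; KeyError if missing, TypeError if not a number."""
    value = config["numerics"]["t_final"]  # KeyError if either key is missing
    if isinstance(value, bool) or not isinstance(value, Real):
        raise TypeError(f"t_final must be a number, got {type(value).__name__}")
    return float(value)


def set_t_final(config: dict[str, Any], time: float) -> None:
    """Write numerics.t_final in place after validating the new value."""
    if isinstance(time, bool) or not isinstance(time, Real):
        raise TypeError(f"time must be a number, got {type(time).__name__}")
    config["numerics"]["t_final"] = float(time)  # KeyError if 'numerics' is missing


cfg = {"numerics": {"t_initial": 0.0, "t_final": 5.0, "fixed_dt": 0.1}}
set_t_final(cfg, 10)
print(get_t_final(cfg))  # 10.0
```

bool is rejected explicitly because Python treats True and False as integers, which would otherwise pass the number check.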

get_initial_simulation_time(restart=False) -> float

Get the initial simulation time in seconds.

This extracts the t_initial parameter from the numerics section, which defines the initial time for the plasma simulation.

Returns:

Initial simulation time in seconds

Raises:
  • KeyError – If the configuration does not contain the required keys

  • TypeError – If the value is not a number

Return type:

float

get_simulation_timestep() -> float

Get the simulation timestep in seconds.

This extracts the fixed_dt parameter from the numerics section, which defines the time step used in the numerical integration.

Returns:

Simulation timestep in seconds

Raises:
  • KeyError – If the configuration does not contain the required keys

  • TypeError – If the value is not a number

Return type:

float
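run() notes that one action step may involve several internal TORAX timesteps. When a fixed timestep is used, the count per action is delta_t_a / fixed_dt; the sketch below checks that this ratio is (numerically) an integer. This is a hypothetical helper, `internal_steps_per_action`; whether TORAX requires such divisibility, or adapts its step size instead, depends on the numerics settings and is not stated here.

```python
import math


def internal_steps_per_action(delta_t_a: float, fixed_dt: float,
                              rel_tol: float = 1e-9) -> int:
    """Number of fixed_dt solver steps in one action window of length delta_t_a.

    Assumes a fixed timestep; raises if delta_t_a is not close to a multiple of fixed_dt.
    """
    if fixed_dt <= 0 or delta_t_a <= 0:
        raise ValueError("timesteps must be positive")
    ratio = delta_t_a / fixed_dt
    n = round(ratio)
    # Compare with a relative tolerance, since e.g. 0.5 / 0.1 may not be exactly 5.0.
    if n == 0 or not math.isclose(ratio, n, rel_tol=rel_tol):
        raise ValueError(f"delta_t_a={delta_t_a} is not a multiple of fixed_dt={fixed_dt}")
    return n


print(internal_steps_per_action(1.0, 0.1))  # 10
```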

get_n_grid_points() -> int

Get the number of radial grid points (rho) in the simulation.

This extracts the n_rho parameter from the geometry section, which defines the number of radial grid points in the simulation. If the parameter is not set, a default value of 25 is used, in accordance with the TORAX default.

Returns:

Number of radial grid points (rho)

Raises:

TypeError – If the value is not an integer

Return type:

int
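The fallback described above can be sketched on a plain dict. `get_n_rho` is a hypothetical helper, and treating a missing geometry section the same as a missing n_rho key is an assumption of this sketch; only the default of 25 and the TypeError for non-integers come from the description.

```python
from numbers import Integral
from typing import Any

DEFAULT_N_RHO = 25  # TORAX default stated in the description above


def get_n_rho(config: dict[str, Any]) -> int:
    """Return geometry.n_rho, falling back to 25 when it is not set."""
    value = config.get("geometry", {}).get("n_rho", DEFAULT_N_RHO)
    # Reject floats such as 25.0 and bools, which subclass int.
    if isinstance(value, bool) or not isinstance(value, Integral):
        raise TypeError(f"n_rho must be an integer, got {type(value).__name__}")
    return int(value)


print(get_n_rho({"geometry": {}}))               # 25
print(get_n_rho({"geometry": {"n_rho": 50}}))    # 50
```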

update_config(action, current_time: float, delta_t_a: float) -> None

Update the simulation configuration with new timing and action parameters.

This method updates the TORAX configuration with new time boundaries and applies the provided action through the action handler. It handles time stepping and rebuilds the TORAX config.

Parameters:
  • action – Action values to be applied through the action handler.

  • current_time (float) – The current simulation time in seconds.

  • delta_t_a (float) – The action duration/time step in seconds.

Raises:

ValueError – If Ip control is requested but Ip_from_parameters is False.

Return type:

None
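The "new time boundaries" step can be pictured as restricting the numerics section to one action window. This is a sketch under assumptions: it supposes the boundaries are numerics.t_initial = current_time and numerics.t_final = current_time + delta_t_a, which this page does not spell out, and it omits the action handler and the rebuild of the TORAX config entirely. `with_time_window` is a hypothetical name; it returns a copy, whereas the real method may modify its configuration in place.

```python
import copy
from typing import Any


def with_time_window(config: dict[str, Any], current_time: float,
                     delta_t_a: float) -> dict[str, Any]:
    """Return a copy of config whose numerics section covers one action window."""
    new_config = copy.deepcopy(config)  # leave the caller's dict untouched
    numerics = new_config.setdefault("numerics", {})
    numerics["t_initial"] = current_time              # assumed start of the window
    numerics["t_final"] = current_time + delta_t_a    # assumed end of the window
    return new_config


base = {"numerics": {"t_initial": 0.0, "t_final": 5.0, "fixed_dt": 0.1}}
window = with_time_window(base, 1.0, 0.5)
print(window["numerics"])  # {'t_initial': 1.0, 't_final': 1.5, 'fixed_dt': 0.1}
```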

get_current_action_values() -> dict[str, Any]

Get the current action values from the action handler.

Returns:

Dictionary of current action values

Return type:

dict[str, Any]

validate_discretization(discretization_torax: str) -> None

Validate the discretization settings.

This method checks that the discretization settings are consistent and valid for the simulation.

Raises:

ValueError – If the discretization settings are invalid

Parameters:

discretization_torax (str)

Return type:

None

torax_plot_helpers

TORAX plotting helper functions for visualization.

This module provides utilities for creating matplotlib figures and updating plots with TORAX simulation data. The functions are designed to work with the TORAX plotting system while supporting both static image generation and real-time visualization updates.

Key functions:
  • create_figure(): Sets up matplotlib figure with TORAX styling and font scaling

  • update_lines(): Updates plot lines with simulation data (spatial profiles or time series)

  • validate_plotdata(): Ensures plot configuration matches available data attributes

  • load_data(): Processes TORAX DataTree output into PlotData format with unit conversions

All of these functions are adapted from the TORAX plotruns_lib module, with modifications that allow them to be used in GymTORAX environments.

gymtorax.torax_wrapper.torax_plot_helpers.create_figure(plot_config: torax._src.plotting.plotruns_lib.FigureProperties, font_scale: float = 1)

Create matplotlib figure with TORAX styling and configurable font scaling.

Sets up a matplotlib figure using the TORAX plot configuration, applies matplotlib RC settings for consistent styling, and creates a grid of subplots. Font sizes are scaled according to the font_scale parameter and applied to the plot_config object in-place. As side effects, this function modifies the global matplotlib RC settings for tick, axes, and figure fonts, and modifies plot_config.default_legend_fontsize and the legend_fontsize of each axis config in-place.

Parameters:
  • plot_config (plotruns_lib.FigureProperties) – TORAX plot configuration containing subplot layout (rows, cols), font sizes, figure size factor, and axes configurations. Modified in-place to apply font scaling.

  • font_scale (float) – Multiplier for all font sizes. Applied to tick labels, axis labels, titles, and legend fonts. Defaults to 1.0.

Returns:

  • fig (matplotlib.figure.Figure): Figure object

  • axes (list[matplotlib.axes.Axes]): list of axes in row-major order (left-to-right, top-to-bottom).

Return type:

tuple[matplotlib.figure.Figure, list[matplotlib.axes.Axes]]
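Because the axes are returned as a flat list in row-major order, the subplot at (row, col) of a grid with n_cols columns sits at index row * n_cols + col. The two hypothetical helpers below make that mapping explicit without needing matplotlib. For reference, the axes array returned by matplotlib's `plt.subplots(rows, cols)` flattens to the same order with `.flatten()` (NumPy's default C order), though whether create_figure builds its grid that way is not stated here.

```python
def axis_index(row: int, col: int, n_cols: int) -> int:
    """Position of subplot (row, col) in a row-major (left-to-right, top-to-bottom) list."""
    if not 0 <= col < n_cols:
        raise IndexError("col out of range")
    return row * n_cols + col


def axis_position(index: int, n_cols: int) -> tuple[int, int]:
    """Inverse mapping: flat list index -> (row, col)."""
    return divmod(index, n_cols)


# In a 2 x 3 grid, the first subplot of the second row is axes[3].
print(axis_index(1, 0, 3))   # 3
print(axis_position(5, 3))   # (1, 2)
```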

gymtorax.torax_wrapper.torax_plot_helpers.update_lines(lines, axes, plot_config, plotdata, t, first_update)

Update or create plot lines with simulation data.

As side effects, this function sets cfg.include_first_timepoint = True on each axis config and, for TIME_SERIES plots on subsequent updates, appends data to the existing line coordinates.

Parameters:
  • lines (list) – Existing matplotlib Line2D objects. Empty on first call.

  • axes (list) – Matplotlib axes objects matching plot_config layout.

  • plot_config (plotruns_lib.FigureProperties) – Defines subplot configurations, each with plot_type, attrs (variable names), labels, and colors.

  • plotdata (plotruns_lib.PlotData) – Simulation data with plasma variables.

  • t (float) – Current simulation time (used for TIME_SERIES updates).

  • first_update (bool) – If True, creates new lines; if False, updates existing.

Returns:

Updated list of Line2D objects for future calls.

Return type:

list

Raises:

ValueError – If plot_type is not SPATIAL or TIME_SERIES.

Note

Uses plotruns_lib.get_rho() to determine x-coordinate for spatial plots. Color cycling follows plot_config.colors list with modulo indexing.
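The TIME_SERIES behavior (create lines on the first update, append one point per call afterwards) and the modulo color cycling can be sketched without matplotlib. `SeriesLine`, `pick_color`, and `update_time_series` are hypothetical stand-ins: `SeriesLine` replaces a matplotlib Line2D, where the equivalent append would go through `Line2D.get_xdata()`/`get_ydata()` and `set_data()`. SPATIAL plots, which redraw a whole profile against rho, are not covered.

```python
from dataclasses import dataclass, field


@dataclass
class SeriesLine:
    """Hypothetical stand-in for a matplotlib Line2D holding x/y data."""
    color: str
    xs: list[float] = field(default_factory=list)
    ys: list[float] = field(default_factory=list)


def pick_color(colors: list[str], i: int) -> str:
    """Cycle through the configured colors with modulo indexing."""
    return colors[i % len(colors)]


def update_time_series(lines: list[SeriesLine], values: list[float], t: float,
                       colors: list[str], first_update: bool) -> list[SeriesLine]:
    """Create one line per variable on the first call, then append (t, value) points."""
    if first_update:
        lines = [SeriesLine(color=pick_color(colors, i)) for i in range(len(values))]
    for line, value in zip(lines, values):
        line.xs.append(t)
        line.ys.append(value)
    return lines


palette = ["r", "g"]
lines = update_time_series([], [1.0, 2.0, 3.0], 0.0, palette, first_update=True)
lines = update_time_series(lines, [1.5, 2.5, 3.5], 0.1, palette, first_update=False)
print([line.color for line in lines])  # ['r', 'g', 'r']
print(lines[0].xs, lines[0].ys)        # [0.0, 0.1] [1.0, 1.5]
```

Returning the list of lines mirrors the documented contract, where the caller passes the result back in on the next update.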

gymtorax.torax_wrapper.torax_plot_helpers.validate_plotdata(plotdata: torax._src.plotting.plotruns_lib.PlotData, plot_config: torax._src.plotting.plotruns_lib.FigureProperties)

Check that all plot configuration attributes exist in plotdata.

Uses introspection to find all available attributes in the PlotData object (both dataclass fields and properties), then verifies that every attribute name listed in the attrs list of each axis in the plot configuration exists.

Parameters:
  • plotdata (plotruns_lib.PlotData) – Data object to check.

  • plot_config (plotruns_lib.FigureProperties) – Plot configuration with axes definitions. Each axis config has an attrs list of variable names.

Raises:

ValueError – If any attribute in plot_config.axes[*].attrs is not found in plotdata. Error message identifies the missing attribute name.
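The introspection step can be sketched with the standard library: dataclass field names come from `dataclasses.fields`, and property names from `inspect.getmembers` on the class. `available_attributes`, `check_attrs`, and the toy `ToyPlotData` class are hypothetical; the real PlotData fields are not reproduced here.

```python
import dataclasses
import inspect


def available_attributes(obj: object) -> set[str]:
    """Dataclass field names plus property names defined on obj's class."""
    names = {f.name for f in dataclasses.fields(obj)} if dataclasses.is_dataclass(obj) else set()
    names |= {name for name, _ in
              inspect.getmembers(type(obj), lambda m: isinstance(m, property))}
    return names


def check_attrs(obj: object, requested: list[str]) -> None:
    """Raise ValueError naming the first requested attribute that obj does not provide."""
    available = available_attributes(obj)
    for name in requested:
        if name not in available:
            raise ValueError(f"Attribute '{name}' not found in plot data")


@dataclasses.dataclass
class ToyPlotData:
    """Toy stand-in for PlotData with two fields and one derived property."""
    te: list
    ti: list

    @property
    def t_avg(self) -> list:
        return [(a + b) / 2 for a, b in zip(self.te, self.ti)]


data = ToyPlotData(te=[1.0, 3.0], ti=[3.0, 5.0])
print(sorted(available_attributes(data)))  # ['t_avg', 'te', 'ti']
check_attrs(data, ["te", "t_avg"])          # passes silently
```

Checking properties as well as fields matters because derived quantities exposed as properties would otherwise be reported as missing.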

gymtorax.torax_wrapper.torax_plot_helpers.load_data(data_tree: DataTree) -> torax._src.plotting.plotruns_lib.PlotData

Convert TORAX DataTree output to PlotData with unit transformations.

Extracts time coordinate and applies unit conversions to match TORAX plotting conventions (A/m² → MA/m², W → MW, m⁻³ → 10²⁰ m⁻³, etc.). Handles hierarchical DataTree structure by extracting from profiles/ and scalars/ branches.

Parameters:

data_tree (xarray.DataTree) – TORAX simulation output.

Returns:

Object with plasma variables in plotting units.

Return type:

plotruns_lib.PlotData
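The unit conversions listed for load_data are plain rescalings: dividing by 10⁶ turns A/m² into MA/m² and W into MW, and dividing by 10²⁰ expresses a density in units of 10²⁰ m⁻³. A minimal sketch with hypothetical names (`UNIT_DIVISOR`, `convert`) covering only the three conversions named above; the full set applied by load_data is not reproduced.

```python
# Divisors for the conversions named above (illustrative, not load_data's own table).
UNIT_DIVISOR = {
    "A/m^2 -> MA/m^2": 1e6,
    "W -> MW": 1e6,
    "m^-3 -> 1e20 m^-3": 1e20,
}


def convert(values: list[float], conversion: str) -> list[float]:
    """Rescale raw SI values into TORAX plotting units."""
    divisor = UNIT_DIVISOR[conversion]  # KeyError for an unknown conversion
    return [v / divisor for v in values]


print(convert([2.5e6], "W -> MW"))  # [2.5]
```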