TORAX Wrapper
The TORAX wrapper is a critical component that bridges the Gymnasium environment interface and the TORAX plasma physics simulation engine. It handles state management and configuration, and provides a clean interface for reinforcement learning interactions.
torax_app
The main application class that orchestrates TORAX simulations.
High-level application interface for running TORAX plasma simulations.
This module provides the ToraxApp class, which wraps the TORAX simulator into a Pythonic interface suitable for reinforcement learning and episodic simulation workflows. It manages the simulation lifecycle, configuration updates, state tracking, and output handling.
This abstraction allows Gymnasium-style environments and control algorithms to interact with TORAX without dealing with its low-level orchestration details.
- class gymtorax.torax_wrapper.torax_app.ToraxApp(config_loader: ConfigLoader, delta_t_a: float, store_history: bool = False)
Bases: object
TORAX simulation application wrapper.
This class provides a high-level interface for running TORAX plasma simulations in an episodic manner, suitable for reinforcement learning environments. It manages the simulation lifecycle, state tracking, and configuration updates.
- The application follows a start/reset -> run -> update cycle:
Initialize with configuration and action timestep
Call reset() to prepare for a new episode
Call run() repeatedly to advance the simulation
Call update_config() between runs to update action parameters
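The cycle above can be sketched with a hypothetical stand-in class, `StubToraxApp`, that mimics only the documented contract: reset() must precede run(), run() returns a (success, done) tuple, and update_config() is called between runs. It does not call TORAX; the stub's internals, the `run_episode` helper, and the `{"Ip": ...}` action format are assumptions for illustration only.

```python
class StubToraxApp:
    """Hypothetical stand-in mimicking the documented ToraxApp lifecycle (no TORAX calls)."""

    def __init__(self, t_final: float, delta_t_a: float):
        self.t_final = t_final
        self.delta_t_a = delta_t_a
        self.is_reset = False
        self.t_current = 0.0
        self.actions = []  # actions applied between runs, kept for inspection

    def reset(self) -> None:
        self.t_current = 0.0
        self.is_reset = True

    def update_config(self, action: dict) -> None:
        self.actions.append(action)

    def run(self) -> tuple[bool, bool]:
        if not self.is_reset:
            raise RuntimeError("reset() must be called before run()")
        # Advance by one action timestep, never past t_final.
        self.t_current = min(self.t_current + self.delta_t_a, self.t_final)
        done = self.t_current >= self.t_final
        return True, done


def run_episode(app, policy) -> int:
    """Reset once, then alternate run() and update_config(); return the number of runs."""
    app.reset()
    steps = 0
    while True:
        success, done = app.run()
        steps += 1
        if not success or done:
            return steps  # a failed step means the environment should reset
        # Apply the next action between runs.
        app.update_config(policy(app.t_current))


app = StubToraxApp(t_final=1.0, delta_t_a=0.25)
n = run_episode(app, policy=lambda t: {"Ip": 15e6})  # hypothetical action format
print(n, app.t_current)  # prints: 4 1.0
```

With four action steps, update_config() is called three times, once between each pair of runs.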
- Variables:
config (ConfigLoader) – Current configuration loader instance
initial_config (ConfigLoader) – Original configuration for resetting
delta_t_a (float) – Action timestep: simulation duration per run() call
store_history (bool) – Whether to store complete simulation history
current_sim_state (ToraxSimState) – Current simulation state
current_sim_output (PostProcessedOutputs) – Current post-processed outputs
state (StateHistory) – Current state history (single timestep)
history_list (list) – Complete history list (if store_history=True)
is_started (bool) – Whether the application has been initialized
t_current (float) – Current simulation time
t_final (float) – Final simulation time for current episode
last_run_time (float) – Timestamp of last run() call (for performance monitoring)
- Parameters:
config_loader (ConfigLoader)
delta_t_a (float)
store_history (bool)
- __init__(config_loader: ConfigLoader, delta_t_a: float, store_history: bool = False)
Initialize ToraxApp with configuration and simulation parameters.
- Parameters:
config_loader (ConfigLoader) – ConfigLoader instance containing TORAX configuration
delta_t_a (float) – Action timestep in seconds. Each call to run() advances the simulation by this amount
store_history (bool) – If True, stores the complete simulation history for later analysis. If False, only keeps the current state (more memory efficient)
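The store_history trade-off can be pictured as two storage policies: append every snapshot to a list, or overwrite a single slot. A minimal sketch with a hypothetical `SnapshotStore` class (not part of gymtorax); its error mirrors the RuntimeError that the history-dependent methods raise when store_history=False.

```python
from typing import Any


class SnapshotStore:
    """Hypothetical illustration of store_history=True vs store_history=False."""

    def __init__(self, store_history: bool = False):
        self.store_history = store_history
        self.current: Any = None
        self.history_list: list[Any] = []

    def record(self, snapshot: Any) -> None:
        self.current = snapshot  # the latest state is always kept
        if self.store_history:
            self.history_list.append(snapshot)  # full trajectory only on request

    def full_history(self) -> list[Any]:
        if not self.store_history:
            raise RuntimeError("History not stored; create with store_history=True")
        return list(self.history_list)
```

The lean policy keeps memory constant per episode, while the full policy grows linearly with the number of recorded steps.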
Note
The application must be reset() before first use. The constructor only sets up instance variables and enables performance monitoring if debug logging is enabled.
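The note above says performance monitoring is enabled only when debug logging is on, and run() logs the interval since the previous call at DEBUG level. One way to get that behavior with the standard library is sketched below; the `RunTimer` class and the logger name are hypothetical, and `time.perf_counter` is an assumed clock choice rather than what ToraxApp necessarily uses.

```python
import logging
import time
from typing import Optional

logger = logging.getLogger("gymtorax.sketch")  # hypothetical logger name


class RunTimer:
    """Logs the wall-clock interval between successive tick() calls, only when DEBUG is enabled."""

    def __init__(self, log: logging.Logger):
        self.log = log
        # Decide once, at construction time, whether timing is worth doing.
        self.enabled = log.isEnabledFor(logging.DEBUG)
        self.last_run_time = time.perf_counter() if self.enabled else None

    def tick(self) -> Optional[float]:
        """Return seconds since the previous tick, or None when timing is disabled."""
        if not self.enabled:
            return None
        now = time.perf_counter()
        interval = now - self.last_run_time
        self.last_run_time = now
        self.log.debug("Time since last run: %.3f s", interval)
        return interval
```

Checking the level once in the constructor means that, with debug logging off, each run pays only a single boolean test.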
- start()
Initialize TORAX simulation components.
- This method sets up all the TORAX simulation infrastructure:
Transport and pedestal models
Geometry provider and source models
Static and dynamic runtime parameters
Solver and MHD models
Step function for simulation advancement
Initial simulation state and outputs
Called automatically by reset() if not already started.
- reset()
Reset the simulation to initial conditions for a new episode.
- This method prepares the application for a new simulation episode by:
Initializing TORAX components if not already started
Resetting simulation state to initial conditions
Creating fresh state history
Setting up time tracking (t_current=0, t_final from config)
Configuring first action step duration
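Keeping an untouched copy of the original configuration (the initial_config attribute) is what lets each episode start from identical settings even after update_config() has mutated the working copy. A minimal sketch of that pattern, assuming the configuration is a plain nested dict (the real code wraps it in a ConfigLoader) and using `copy.deepcopy`; both helper names are hypothetical.

```python
import copy


def make_working_config(initial_config: dict) -> dict:
    """Return an independent copy so episode-time edits never touch the original."""
    return copy.deepcopy(initial_config)


def reset_time_tracking(config: dict) -> tuple[float, float]:
    """Return (t_current, t_final) for a fresh episode: time starts at 0, t_final comes from the config."""
    return 0.0, config["numerics"]["t_final"]


initial = {"numerics": {"t_final": 5.0, "fixed_dt": 0.1}}
working = make_working_config(initial)
working["numerics"]["t_final"] = 0.1  # e.g. narrowed by an earlier action step
fresh = make_working_config(initial)  # reset: start again from the original
t_current, t_final = reset_time_tracking(fresh)
```

A shallow copy would not be enough here, because the nested "numerics" dict would still be shared with the original.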
- run() → tuple[bool, bool]
Execute one simulation step from t_current to t_current + delta_t_a.
This method advances the TORAX simulation by one action timestep, which may involve multiple internal TORAX timesteps. It handles:
Performance timing (if debug logging enabled)
TORAX run_loop execution with current configuration
State and output management
Error handling and recovery
Time progression tracking
- Returns:
- success (bool): True if simulation step completed successfully,
False if an error occurred or simulation reached final time.
done (bool): True if the whole simulation is done.
- Return type:
tuple[bool, bool]
- Raises:
RuntimeError – If reset() has not been called before running.
Note
Call update_config() between runs to modify simulation parameters
Returns done=True when t_current >= t_final (episode complete)
Performance timing logged at DEBUG level shows interval since last run
Errors during simulation return False (environment should reset)
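Because completion is decided by comparing t_current against t_final, repeated floating-point additions of delta_t_a can leave t_current slightly short of t_final (ten additions of 0.1 give 0.9999999999999999, not 1.0). A sketch of a tolerant step-and-check, assuming that clamping to t_final and an absolute tolerance of 1e-9 are acceptable; the `advance_time` name is hypothetical, and this is not necessarily how ToraxApp tracks time.

```python
import math


def advance_time(t_current: float, delta_t_a: float, t_final: float) -> tuple[float, bool]:
    """Advance by one action step; return (new_time, done)."""
    t_next = min(t_current + delta_t_a, t_final)  # never overshoot the episode end
    # Treat values within floating-point noise of t_final as having reached it.
    done = t_next >= t_final or math.isclose(t_next, t_final, rel_tol=0.0, abs_tol=1e-9)
    if done:
        t_next = t_final
    return t_next, done


t, done, steps = 0.0, False, 0
while not done:
    t, done = advance_time(t, 0.1, 1.0)
    steps += 1
```

Without the tolerance, the loop above would take an eleventh, almost zero-length step to cover the last 1e-16 seconds.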
- update_config(action) → None
Update simulation configuration with new action parameters.
This method applies new control parameters to the TORAX configuration for the next simulation step.
- Parameters:
action – Action dictionary containing new parameter values. Must match the format expected by the ConfigLoader.
- Raises:
ValueError – If action format is invalid or configuration update fails.
- Return type:
None
- get_output_datatree(start: int = 0, end: int = -1) → DataTree
Return the full simulation history as an xarray DataTree.
This method reconstructs the complete trajectory of the simulation, including all state and post-processed output snapshots, as an xarray DataTree suitable for analysis and visualization. If start and end are specified, only data between those time values (inclusive) is selected for every dataset in the DataTree that has a ‘time’ coordinate. Requires that the ToraxApp was initialized with store_history=True so that the full history is available.
- Parameters:
start (int)
end (int)
- Returns:
- The complete simulation history as an xarray DataTree,
with all timesteps and outputs, or only the selected time range if specified.
- Return type:
DataTree
- Raises:
RuntimeError – If store_history was not enabled and thus no history is available.
- save_output_file(file_name)
Save complete simulation history to NetCDF file.
This method saves the full simulation trajectory to a NetCDF file suitable for analysis and visualization. Requires store_history=True in the constructor.
- Parameters:
file_name (str) – Output file path with .nc extension
- Raises:
RuntimeError – If store_history=False (no history to save)
ValueError – If file writing fails
- get_state_data()
Get current simulation state as xarray DataTree.
This method returns the current simulation state in xarray format, suitable for observation extraction and analysis.
- Returns:
Current simulation state.
- Return type:
DataTree
- Raises:
RuntimeError – If simulation state has not been computed yet.
Note
Returns single-timestep state (current moment)
For full history, use save_output_file() with store_history=True
config_loader
Configuration system for managing TORAX physics parameters and simulation settings.
Configuration loader for TORAX simulation package.
This module provides a wrapper around TORAX configuration dictionaries, offering convenient access to common simulation parameters and configuration management for Gymnasium environments.
- class gymtorax.torax_wrapper.config_loader.ConfigLoader(config: dict[str, Any], action_handler: ActionHandler)
Bases: object
A wrapper class for TORAX configuration management.
This class handles the conversion between Python dictionaries and TORAX’s internal configuration format, providing convenient access to simulation parameters commonly needed in Gymnasium environments.
- Parameters:
config (dict[str, Any])
action_handler (ActionHandler)
- __init__(config: dict[str, Any], action_handler: ActionHandler)
Initialize the configuration loader.
- Parameters:
config (dict[str, Any]) – Dictionary containing TORAX configuration parameters.
action_handler (ActionHandler) – ActionHandler instance for managing actions.
- Raises:
ValueError – If the configuration dictionary is invalid
TypeError – If config is not a dictionary
- get_total_simulation_time() → float
Get the total simulation time in seconds.
This extracts the t_final parameter from the numerics section, which defines how long the plasma simulation should run.
- set_total_simulation_time(time: float) → None
Set the total simulation time in seconds.
This updates the t_final parameter in the numerics section, which defines how long the plasma simulation should run.
- get_initial_simulation_time(restart=False) → float
Get the initial simulation time in seconds.
This extracts the t_initial parameter from the numerics section, which defines the initial time for the plasma simulation.
- get_simulation_timestep() → float
Get the simulation timestep in seconds.
This extracts the fixed_dt parameter from the numerics section, which defines the time step used in the numerical integration.
- get_n_grid_points() → int
Get the number of radial grid points (rho) in the simulation.
This extracts the n_rho parameter from the geometry section, which defines the number of radial grid points in the simulation. If the parameter is not set, a default value of 25 is used, in accordance with the TORAX settings.
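The getters above amount to lookups in the numerics and geometry sections of the TORAX configuration dictionary. A minimal sketch on a plain dict, assuming those sections are top-level keys named "numerics" and "geometry"; the helper functions are hypothetical stand-ins for the ConfigLoader methods, and only the n_rho default of 25 comes from the text above.

```python
from typing import Any

DEFAULT_N_RHO = 25  # default documented above for an unset n_rho


def get_total_simulation_time(config: dict[str, Any]) -> float:
    return float(config["numerics"]["t_final"])


def set_total_simulation_time(config: dict[str, Any], time: float) -> None:
    config["numerics"]["t_final"] = time  # mutates the dict in place


def get_simulation_timestep(config: dict[str, Any]) -> float:
    return float(config["numerics"]["fixed_dt"])


def get_n_grid_points(config: dict[str, Any]) -> int:
    # Fall back to the default when n_rho (or the whole geometry section) is absent.
    return int(config.get("geometry", {}).get("n_rho", DEFAULT_N_RHO))


cfg = {"numerics": {"t_initial": 0.0, "t_final": 5.0, "fixed_dt": 0.05}, "geometry": {}}
```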
- update_config(action, current_time: float, delta_t_a: float) → None
Update the simulation configuration with new timing and action parameters.
This method updates the TORAX configuration with new time boundaries and applies the provided action through the action handler. It handles time stepping and rebuilds the TORAX config.
- Parameters:
action
current_time (float)
delta_t_a (float)
- Raises:
ValueError – If Ip control is requested but Ip_from_parameters is False.
- Return type:
None
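The time-window part of update_config() can be sketched as narrowing the numerics section to [current_time, current_time + delta_t_a] before the configuration is rebuilt. Everything about the action side is an assumption for illustration: the flat action dict, the "profile_conditions" target for Ip, the placement of the Ip_from_parameters flag under "geometry", and the hypothetical `update_config_sketch` name. The real method delegates action handling to the ActionHandler.

```python
import copy
from typing import Any


def update_config_sketch(config: dict[str, Any], action: dict[str, Any],
                         current_time: float, delta_t_a: float) -> dict[str, Any]:
    """Return a new config covering one action step, with the action applied."""
    new = copy.deepcopy(config)  # leave the caller's dict untouched
    numerics = new["numerics"]
    numerics["t_initial"] = current_time
    numerics["t_final"] = current_time + delta_t_a
    if "Ip" in action:
        # Assumed flag location; a missing flag is treated as True (assumption).
        if not new.get("geometry", {}).get("Ip_from_parameters", True):
            raise ValueError("Ip control requires Ip_from_parameters=True")
        new.setdefault("profile_conditions", {})["Ip"] = action["Ip"]
    return new
```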
- get_current_action_values() → dict[str, Any]
Get the current action values from the action handler.
- validate_discretization(discretization_torax: str) → None
Validate the discretization settings.
This method checks that the discretization settings are consistent and valid for the simulation.
- Raises:
ValueError – If the discretization settings are invalid
- Parameters:
discretization_torax (str)
- Return type:
None
torax_plot_helpers
TORAX plotting helper functions for visualization.
This module provides utilities for creating matplotlib figures and updating plots with TORAX simulation data. The functions are designed to work with TORAX plotting system while supporting both static image generation and real-time visualization updates.
- Key functions:
create_figure(): Sets up matplotlib figure with TORAX styling and font scaling
update_lines(): Updates plot lines with simulation data (spatial profiles or time series)
validate_plotdata(): Ensures plot configuration matches available data attributes
load_data(): Processes TORAX DataTree output into PlotData format with unit conversions
All of these functions are adapted from the TORAX plotruns_lib module, with modifications that allow them to be used in GymTORAX environments.
- gymtorax.torax_wrapper.torax_plot_helpers.create_figure(plot_config: torax._src.plotting.plotruns_lib.FigureProperties, font_scale: float = 1)
Create matplotlib figure with TORAX styling and configurable font scaling.
Sets up a matplotlib figure using the TORAX plot configuration, applies matplotlib RC settings for consistent styling, and creates a grid of subplots. Font sizes are scaled according to the font_scale parameter and applied to the plot_config object in-place. As side effects, this function modifies matplotlib global RC settings for tick, axes, and figure fonts, and modifies plot_config.default_legend_fontsize and the per-axis legend_fontsize in-place.
- Parameters:
plot_config (plotruns_lib.FigureProperties) – TORAX plot configuration containing subplot layout (rows, cols), font sizes, figure size factor, and axes configurations. Modified in-place to apply font scaling.
font_scale (float) – Multiplier for all font sizes. Applied to tick labels, axis labels, titles, and legend fonts. Defaults to 1.0.
- Returns:
fig (matplotlib.figure.Figure): Figure object
axes (list[matplotlib.axes.Axes]): list of axes in row-major order (left-to-right, top-to-bottom).
- Return type:
tuple[matplotlib.figure.Figure, list[matplotlib.axes.Axes]]
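The in-place font scaling and the row-major axes ordering can be pictured with stand-in dataclasses. `FigurePropsStub`, `AxisStub`, and both helper functions are hypothetical and only mimic the legend fields named above; skipping axes whose legend_fontsize is None is also an assumption of the sketch. Real code would additionally update matplotlib's RC settings, which is left out so the sketch runs without matplotlib.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class AxisStub:
    legend_fontsize: Optional[float] = None  # None means "use the figure default"


@dataclass
class FigurePropsStub:
    rows: int
    cols: int
    default_legend_fontsize: float = 10.0
    axes: list[AxisStub] = field(default_factory=list)


def scale_legend_fonts(cfg: FigurePropsStub, font_scale: float = 1.0) -> None:
    """Multiply legend font sizes in place, mirroring create_figure's side effect."""
    cfg.default_legend_fontsize *= font_scale
    for ax in cfg.axes:
        if ax.legend_fontsize is not None:
            ax.legend_fontsize *= font_scale


def row_major_index(row: int, col: int, cols: int) -> int:
    """Position of subplot (row, col) in a left-to-right, top-to-bottom axes list."""
    return row * cols + col
```

Because the scaling mutates the configuration object, calling it twice on the same object compounds the factor; a caller reusing one configuration across figures would need a fresh copy each time.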
- gymtorax.torax_wrapper.torax_plot_helpers.update_lines(lines, axes, plot_config, plotdata, t, first_update)
Update or create plot lines with simulation data.
As side effects, this function sets cfg.include_first_timepoint = True on each axis config and, for TIME_SERIES plots on subsequent updates, appends data to the existing line coordinates.
- Parameters:
lines (list) – Existing matplotlib Line2D objects. Empty on first call.
axes (list) – Matplotlib axes objects matching plot_config layout.
plot_config (plotruns_lib.FigureProperties) – Defines subplot configurations, each with plot_type, attrs (variable names), labels, and colors.
plotdata (plotruns_lib.PlotData) – Simulation data with plasma variables.
t (float) – Current simulation time (used for TIME_SERIES updates).
first_update (bool) – If True, creates new lines; if False, updates existing.
- Returns:
Updated list of Line2D objects for future calls.
- Return type:
list
- Raises:
ValueError – If plot_type is not SPATIAL or TIME_SERIES.
Note
Uses plotruns_lib.get_rho() to determine the x-coordinate for spatial plots. Color cycling follows the plot_config.colors list with modulo indexing.
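The two update modes and the color cycling can be sketched without matplotlib using a stand-in line object. `LineStub`, `pick_color`, and `update_line` are hypothetical, and the string plot types stand in for the real plot-type values. SPATIAL lines get their whole profile replaced, TIME_SERIES lines get one (t, value) point appended, and colors wrap around with modulo indexing.

```python
from dataclasses import dataclass, field


@dataclass
class LineStub:
    color: str
    xdata: list[float] = field(default_factory=list)
    ydata: list[float] = field(default_factory=list)


def pick_color(colors: list[str], i: int) -> str:
    return colors[i % len(colors)]  # cycle through the palette


def update_line(line: LineStub, plot_type: str, x, y, t: float) -> None:
    if plot_type == "SPATIAL":
        line.xdata, line.ydata = list(x), list(y)  # replace the whole radial profile
    elif plot_type == "TIME_SERIES":
        line.xdata.append(t)  # extend the trace by one point
        line.ydata.append(y)
    else:
        raise ValueError(f"Unknown plot_type: {plot_type}")
```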
- gymtorax.torax_wrapper.torax_plot_helpers.validate_plotdata(plotdata: torax._src.plotting.plotruns_lib.PlotData, plot_config: torax._src.plotting.plotruns_lib.FigureProperties)
Check that all plot configuration attributes exist in plotdata.
Uses introspection to find all available attributes in the PlotData object (both dataclass fields and properties), then verifies that every attribute name listed in the axes.attrs lists of the plot configuration exists.
- Parameters:
plotdata (plotruns_lib.PlotData) – Data object to check.
plot_config (plotruns_lib.FigureProperties) – Plot configuration with axes definitions. Each axis config has an attrs list of variable names.
- Raises:
ValueError – If any attribute in plot_config.axes[*].attrs is not found in plotdata. The error message identifies the missing attribute name.
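The introspection step can be sketched with the standard library: collect the dataclass fields plus property names, then check every configured attribute. `PlotDataStub`, its variable names, and the list-of-lists `axes_attrs` input are hypothetical stand-ins for plotruns_lib.PlotData and FigureProperties.

```python
import dataclasses
import inspect


def available_attributes(obj) -> set[str]:
    """Dataclass field names plus names of properties defined on the object's class."""
    names = {f.name for f in dataclasses.fields(obj)} if dataclasses.is_dataclass(obj) else set()
    names |= {name for name, _ in inspect.getmembers(type(obj), lambda m: isinstance(m, property))}
    return names


def validate_attrs(obj, axes_attrs: list[list[str]]) -> None:
    available = available_attributes(obj)
    for attrs in axes_attrs:
        for attr in attrs:
            if attr not in available:
                raise ValueError(f"Attribute '{attr}' in plot_config does not exist in plotdata")


@dataclasses.dataclass
class PlotDataStub:
    T_e: list
    T_i: list

    @property
    def T_e_mean(self) -> float:
        return sum(self.T_e) / len(self.T_e)
```

Looking properties up on the class rather than the instance avoids evaluating them, so an expensive or failing property cannot break validation.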
- gymtorax.torax_wrapper.torax_plot_helpers.load_data(data_tree: DataTree) → torax._src.plotting.plotruns_lib.PlotData
Convert TORAX DataTree output to PlotData with unit transformations.
Extracts the time coordinate and applies unit conversions to match TORAX plotting conventions (A/m² → MA/m², W → MW, m⁻³ → 10²⁰ m⁻³, etc.). Handles the hierarchical DataTree structure by extracting from the profiles/ and scalars/ branches.
- Parameters:
data_tree (xarray.DataTree) – TORAX simulation output.
- Returns:
Object with plasma variables in plotting units.
- Return type:
plotruns_lib.PlotData
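The unit conversions listed above are plain rescalings: A/m² to MA/m² and W to MW divide by 10⁶, and m⁻³ to 10²⁰ m⁻³ divides by 10²⁰. A sketch that applies a conversion table to raw values; the table, the `to_plot_units` helper, and the variable names in `raw` are hypothetical examples, not the actual set of variables load_data processes.

```python
# Divisors that map SI values onto the plotting units quoted above.
UNIT_DIVISORS = {
    "A/m^2": 1e6,  # -> MA/m^2
    "W": 1e6,      # -> MW
    "m^-3": 1e20,  # -> 1e20 m^-3
}


def to_plot_units(values: list[float], si_unit: str) -> list[float]:
    divisor = UNIT_DIVISORS.get(si_unit, 1.0)  # unknown units pass through unchanged
    return [v / divisor for v in values]


raw = {
    "j_total": ([2.0e6, 1.0e6], "A/m^2"),  # hypothetical current density profile
    "P_ohmic": ([5.0e6], "W"),             # hypothetical scalar power
    "n_e": ([1.0e20, 0.5e20], "m^-3"),     # hypothetical density profile
}
plot_data = {name: to_plot_units(v, unit) for name, (v, unit) in raw.items()}
```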