Source code for gymtorax.torax_wrapper.torax_app

"""High-level application interface for running TORAX plasma simulations.

This module provides the `ToraxApp` class, which wraps the TORAX simulator
into a Pythonic interface suitable for reinforcement learning and episodic
simulation workflows. It manages the simulation lifecycle, configuration updates,
state tracking, and output handling.

This abstraction allows Gymnasium-style environments and control algorithms
to interact with TORAX without dealing with its low-level orchestration details.
"""

import copy
import logging
import time

from torax._src import state
from torax._src.config import build_runtime_params
from torax._src.orchestration import initial_state as initial_state_lib
from torax._src.orchestration import run_loop, step_function
from torax._src.orchestration.sim_state import ToraxSimState
from torax._src.output_tools import output
from torax._src.output_tools.post_processing import PostProcessedOutputs
from torax._src.sources import source_models as source_models_lib
from torax._src.state import SimError
from xarray import DataTree

from .config_loader import ConfigLoader

# Set up logger for this module
logger = logging.getLogger(__name__)


[docs] class ToraxApp: """TORAX simulation application wrapper. This class provides a high-level interface for running TORAX plasma simulations in an episodic manner, suitable for reinforcement learning environments. It manages the simulation lifecycle, state tracking, and configuration updates. The application follows a start/reset -> run -> update cycle: 1. Initialize with configuration and action timestep 2. Call ``reset()`` to prepare for a new episode 3. Call ``run()`` repeatedly to advance the simulation 4. Call ``update_config()`` between runs to update action parameters Attributes: config (ConfigLoader): Current configuration loader instance initial_config (ConfigLoader): Original configuration for resetting delta_t_a (float): Action timestep - simulation duration per ``run()`` store_history (bool): Whether to store complete simulation history current_sim_state (ToraxSimState): Current simulation state current_sim_output (PostProcessedOutputs): Current post-processed outputs state (StateHistory): Current state history (single timestep) history_list (list): Complete history list (if ``store_history=True``) is_started (bool): Whether the application has been initialized t_current (float): Current simulation time t_final (float): Final simulation time for current episode last_run_time (float): Timestamp of last ``run()`` call (for performance monitoring) """
[docs] def __init__( self, config_loader: ConfigLoader, delta_t_a: float, store_history: bool = False ): """Initialize ToraxApp with configuration and simulation parameters. Args: config_loader: `ConfigLoader` instance containing TORAX configuration delta_t_a: Action timestep in seconds. Each call to ``run()`` advances simulation by this amount store_history: If ``True``, stores complete simulation history for later analysis. If ``False``, only keeps current state (more memory efficient) Note: The application must be ``reset()`` before first use. The constructor only sets up instance variables and enables performance monitoring if debug logging is enabled. """ # Store configuration and simulation parameters self.store_history = store_history self.initial_config: ConfigLoader = config_loader self.delta_t_a = delta_t_a # Initialize state containers (will be populated by reset()) self.current_sim_state: ToraxSimState | None = None self.current_sim_output: PostProcessedOutputs | None = None self.state: output.StateHistory | None = ( None # Current state history (single timestep) ) # Optional full history storage for analysis/debugging if self.store_history is True: self.history_list: list = [] # Performance monitoring (only if debug logging enabled) if logger.isEnabledFor(logging.DEBUG): self.last_run_time = None # Track initialization state self.is_started: bool = False
[docs] def start(self): """Initialize TORAX simulation components. This method sets up all the TORAX simulation infrastructure: - Transport and pedestal models - Geometry provider and source models - Static and dynamic runtime parameters - Solver and MHD models - Step function for simulation advancement - Initial simulation state and outputs Called automatically by ``reset()`` if not already started. """ # Build physics models for plasma simulation transport_model = ( self.initial_config.config_torax.transport.build_transport_model() ) pedestal_model = ( self.initial_config.config_torax.pedestal.build_pedestal_model() ) # Geometry provider defines the spatial grid and magnetic geometry self.geometry_provider = ( self.initial_config.config_torax.geometry.build_provider ) # Source models handle heating, current drive, and other plasma sources source_models = source_models_lib.SourceModels( self.initial_config.config_torax.sources, neoclassical=self.initial_config.config_torax.neoclassical, ) # Static runtime parameters (do not change during simulation) self.static_runtime_params_slice = ( build_runtime_params.build_static_params_from_config( self.initial_config.config_torax ) ) # Build the main physics solver (handles transport equations) solver = self.initial_config.config_torax.solver.build_solver( static_runtime_params_slice=self.static_runtime_params_slice, transport_model=transport_model, source_models=source_models, pedestal_model=pedestal_model, ) # MHD models for magnetohydrodynamic instabilities and limits mhd_models = self.initial_config.config_torax.mhd.build_mhd_models( static_runtime_params_slice=self.static_runtime_params_slice, transport_model=transport_model, source_models=source_models, pedestal_model=pedestal_model, ) # Main simulation step function - advances physics by one timestep self.step_fn = step_function.SimulationStepFn( solver=solver, time_step_calculator=self.initial_config.config_torax.time_step_calculator.time_step_calculator, transport_model=transport_model, pedestal_model=pedestal_model, mhd_models=mhd_models, ) # Dynamic runtime parameters (can change during simulation via actions) self.dynamic_runtime_params_slice_provider = ( build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config( self.initial_config.config_torax ) ) if ( self.initial_config.config_torax.restart and self.initial_config.config_torax.restart.do_restart ): self.initial_sim_state, self.initial_sim_output = ( initial_state_lib.get_initial_state_and_post_processed_outputs_from_file( t_initial=self.initial_config.config_torax.numerics.t_initial, file_restart=self.initial_config.config_torax.restart, static_runtime_params_slice=self.static_runtime_params_slice, dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider, geometry_provider=self.geometry_provider, step_fn=self.step_fn, ) ) else: self.initial_sim_state, self.initial_sim_output = ( initial_state_lib.get_initial_state_and_post_processed_outputs( t=self.initial_config.config_torax.numerics.t_initial, static_runtime_params_slice=self.static_runtime_params_slice, dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider, geometry_provider=self.geometry_provider, step_fn=self.step_fn, ) ) # Create state history container with initial state state_history = output.StateHistory( state_history=[self.initial_sim_state], post_processed_outputs_history=[self.initial_sim_output], sim_error=SimError.NO_ERROR, torax_config=self.initial_config.config_torax, ) # Update state and configuration references self.state = state_history # Mark as initialized self.is_started = True logger.debug(" ToraxApp started.")
[docs] def reset(self): """Reset the simulation to initial conditions for a new episode. This method prepares the application for a new simulation episode by: - Initializing TORAX components if not already started - Resetting simulation state to initial conditions - Creating fresh state history - Setting up time tracking (t_current=0, t_final from config) - Configuring first action step duration """ # Initialize TORAX physics models if not already done if self.is_started is False: self.start() # Store initial state in history if history tracking is enabled if self.store_history is True: self.history_list: list = [] self.history_list.append((self.initial_sim_state, self.initial_sim_output)) # Reset current simulation state to initial conditions self.current_sim_state = self.initial_sim_state self.current_sim_output = self.initial_sim_output # Rebuild geometry provider with initial configuration # This handles geometry changes (e.g., current profile modifications) self.geometry_provider = ( self.initial_config.config_torax.geometry.build_provider ) # Rebuild dynamic runtime parameters provider with initial configuration # This ensures time-dependent parameters reflect the updated config self.dynamic_runtime_params_slice_provider = ( build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config( self.initial_config.config_torax ) ) # Create state history container with initial state state_history = output.StateHistory( state_history=[self.current_sim_state], post_processed_outputs_history=[self.current_sim_output], sim_error=SimError.NO_ERROR, torax_config=self.initial_config.config_torax, ) # Update state and configuration references self.state = state_history self.config = copy.deepcopy(self.initial_config) # Reset time tracking to beginning of episode if ( self.initial_config.config_torax.restart and self.initial_config.config_torax.restart.do_restart ): self.t_current = self.config.get_initial_simulation_time(restart=True) else: self.t_current = self.config.get_initial_simulation_time() self.t_final = self.config.get_total_simulation_time() # Set simulation end time to first action timestep self.config.set_total_simulation_time( self.delta_t_a ) # End for the first action step logger.debug(" ToraxApp reset.")
[docs] def run(self) -> tuple[bool, bool]: """Execute one simulation step from t_current to t_current + delta_t_a. This method advances the TORAX simulation by one action timestep, which may involve multiple internal TORAX timesteps. It handles: - Performance timing (if debug logging enabled) - TORAX run_loop execution with current configuration - State and output management - Error handling and recovery - Time progression tracking Returns: tuple[bool, bool]: - success (bool): True if simulation step completed successfully, False if an error occurred or simulation reached final time. - done (bool): True if whole simulation is done. Raises: RuntimeError: If reset() has not been called before running. Note: - Call update_config() between runs to modify simulation parameters - Returns True when `t_current >= t_final` (episode complete) - Performance timing logged at DEBUG level shows interval since last run - Errors during simulation return False (environment should reset) """ # Ensure simulation has been initialized if self.is_started is False: raise RuntimeError( "ToraxApp must be started before running the simulation." ) try: # Performance monitoring - track time between runs if logger.isEnabledFor(logging.DEBUG): current_time = time.perf_counter() interval = ( current_time - self.last_run_time if self.last_run_time is not None else 0 ) logger.debug( f" running simulation step at {self.t_current}/{self.t_final}s." ) logger.debug(f" time since last run: {interval:.2f} seconds.") # Update timing reference for next run if logger.isEnabledFor(logging.DEBUG): self.last_run_time = current_time # Execute TORAX simulation loop for one action timestep # This may involve multiple internal physics timesteps sim_states_list, post_processed_outputs_list, sim_error = run_loop.run_loop( static_runtime_params_slice=self.static_runtime_params_slice, dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider, geometry_provider=self.geometry_provider, initial_state=self.current_sim_state, initial_post_processed_outputs=self.current_sim_output, restart_case=True, # Continue from current state step_fn=self.step_fn, log_timestep_info=False, # Suppress internal TORAX logging progress_bar=False, # No progress bar for individual steps ) except Exception as e: logger.error( f" an error occurred during the simulation run: {e}. The environment will reset" ) return False, False # Check if TORAX simulation encountered internal errors if sim_error != state.SimError.NO_ERROR: logger.error("simulation terminated with an error.") sim_error.log_error() logger.error(" The environment will reset.") return False, False # Update current state to final state from simulation step self.current_sim_state = sim_states_list[-1] self.current_sim_output = post_processed_outputs_list[-1] # Store results in history if history tracking is enabled if self.store_history is True: self.history_list.append( [sim_states_list[-1], post_processed_outputs_list[-1]] ) # Update state history container with new results self.state = output.StateHistory( state_history=[self.current_sim_state], post_processed_outputs_history=[self.current_sim_output], sim_error=sim_error, torax_config=self.config.config_torax, ) # Advance time by one action timestep self.t_current += self.delta_t_a # Check if we have reached the end of the episode if self.t_current > self.t_final: logger.debug(" simulation run terminated successfully.") return True, True return True, False
[docs] def update_config(self, action) -> None: """Update simulation configuration with new action parameters. This method applies new control parameters to the TORAX configuration for the next simulation step. Args: action: Action dictionary containing new parameter values. Must match the format expected by the ConfigLoader. Raises: ValueError: If action format is invalid or configuration update fails. """ try: # Apply action parameters to configuration with time constraints self.config.update_config( action, self.t_current, # Start time for this step self.delta_t_a, ) # Action timestep duration except ValueError as e: raise ValueError(f"Error updating configuration: {e}") # Rebuild geometry provider with updated configuration # This handles geometry changes (e.g., current profile modifications) self.geometry_provider = self.config.config_torax.geometry.build_provider # Rebuild dynamic runtime parameters provider with new configuration # This ensures time-dependent parameters reflect the updated config self.dynamic_runtime_params_slice_provider = ( build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config( self.config.config_torax ) )
[docs] def get_output_datatree(self, start: int = 0, end: int = -1) -> DataTree: """Return the full simulation history as an xarray DataTree. This method reconstructs the complete trajectory of the simulation, including all state and post-processed output snapshots, as an xarray DataTree suitable for analysis and visualization. If `beginning` and `end` are specified, only data between those time values (inclusive) will be selected for all datasets in the DataTree that have a 'time' coordinate. Requires that the ToraxApp was initialized with `store_history=True` so that the full history is available. Args: start (int or float): Start time for selection. Defaults to ``0``. end (int or float): End time for selection. Defaults to ``-1`` (no upper limit). Returns: xarray.DataTree: The complete simulation history as an xarray DataTree, with all timesteps and outputs, or only the selected time range if specified. Raises: RuntimeError: If ``store_history`` was not enabled and thus no history is available. """ if self.store_history is False: raise RuntimeError() state_history = [output[0] for output in self.history_list] post_processed_outputs_history = [output[1] for output in self.history_list] state_history = output.StateHistory( state_history=state_history[start:end], post_processed_outputs_history=post_processed_outputs_history[start:end], sim_error=SimError.NO_ERROR, torax_config=self.config.config_torax, ) dt = state_history.simulation_output_to_xr(self.config.config_torax.restart) return dt
[docs] def save_output_file(self, file_name): """Save complete simulation history to NetCDF file. This method saves the full simulation trajectory to a NetCDF file suitable for analysis and visualization. Requires ``store_history=True`` in constructor. Args: file_name (str): Output file path with .nc extension Raises: RuntimeError: If ``store_history=False`` (no history to save) ValueError: If file writing fails """ if self.store_history is False: raise RuntimeError() state_history = [output[0] for output in self.history_list] post_processed_outputs_history = [output[1] for output in self.history_list] state_history = output.StateHistory( state_history=state_history, post_processed_outputs_history=post_processed_outputs_history, sim_error=SimError.NO_ERROR, torax_config=self.config.config_torax, ) dt = state_history.simulation_output_to_xr(self.config.config_torax.restart) try: dt.to_netcdf(file_name, engine="h5netcdf", mode="w") except Exception as e: raise ValueError(f"An error occurred while saving: {e}")
[docs] def get_state_data(self): """Get current simulation state as xarray DataTree. This method returns the current simulation state in xarray format, suitable for observation extraction and analysis. Returns: xarray.DataTree: Current simulation state. Raises: RuntimeError: If simulation state has not been computed yet. Note: - Returns single-timestep state (current moment) - For full history, use ``save_output_file()`` with ``store_history=True`` """ if self.state is None: raise RuntimeError("Simulation state has not been computed yet.") data = self.state.simulation_output_to_xr() return data