"""TORAX plotting helper functions for visualization.
This module provides utilities for creating matplotlib figures and updating plots
with TORAX simulation data. The functions are designed to work with the TORAX
plotting system and support both static image generation and real-time
visualization updates.
Key functions:
- `create_figure()`: Sets up a matplotlib figure with TORAX styling and font scaling
- `update_lines()`: Updates plot lines with simulation data (spatial profiles or time series)
- `validate_plotdata()`: Ensures plot configuration matches available data attributes
- `load_data()`: Processes TORAX `DataTree` output into `PlotData` format with unit conversions
All of these functions are adapted from the TORAX ``plotruns_lib`` module, with
modifications that allow them to be used in the GymTORAX environments.
"""
import inspect
import logging
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from torax._src.output_tools import output
from torax._src.plotting import plotruns_lib
# Logger named after this module.
logger = logging.getLogger(__name__)

# Font scaling constants used for figure styling.
FONT_SCALE_BASE = 1.0  # Base font scaling factor.
FONT_SCALE_PER_ROW = 0.3  # Extra font scaling added for each subplot row.
[docs]
def update_lines(lines, axes, plot_config, plotdata, t, first_update):
    """Update or create plot lines with simulation data.

    On the first update, one `Line2D` is created per plotted attribute and
    appended to ``lines``, which is mutated in place. On later updates, the
    existing lines are modified instead. They are matched to attributes in the
    same order in which they were created:

    * `SPATIAL` lines get new x/y data: ``rho`` from
      ``plotruns_lib.get_rho()`` and the attribute's profile at time index 0
      of ``plotdata``.
    * `TIME_SERIES` lines get ``t`` appended to their x data and the
      attribute's data appended to their y data.

    As a side effect, ``cfg.include_first_timepoint`` is set to ``True`` on
    every axis config that is visited.

    Args:
        lines (list): Existing matplotlib `Line2D` objects. Empty on first call.
        axes (list): Matplotlib axes objects matching `plot_config` layout.
        plot_config (plotruns_lib.FigureProperties): Defines subplot
            configurations (``plot_config.axes``), each with `plot_type`,
            `attrs` (variable names) and `labels`, plus the `colors` list
            shared by all axes.
        plotdata (plotruns_lib.PlotData): Simulation data with plasma variables.
        t (float): Current simulation time (used for `TIME_SERIES` updates).
        first_update (bool): If truthy, creates new lines; otherwise updates
            the existing ones in ``lines``.

    Returns:
        list: The ``lines`` list, for use in future calls.

    Raises:
        ValueError: If `plot_type` is not `SPATIAL` or `TIME_SERIES`.

    Note:
        Colors cycle through ``plot_config.colors`` with modulo indexing and
        restart from the first color on every axis.
    """
    line_idx = 0  # Position in ``lines``; runs across all axes.
    for ax, cfg in zip(axes, plot_config.axes):
        # NOTE(review): the original author found this flag necessary without
        # recording why; presumably downstream plotruns_lib code reads it.
        cfg.include_first_timepoint = True
        if cfg.plot_type == plotruns_lib.PlotType.SPATIAL:
            for color_idx, (attr, label) in enumerate(zip(cfg.attrs, cfg.labels)):
                data = getattr(plotdata, attr)
                rho = plotruns_lib.get_rho(plotdata, attr)
                if first_update:
                    (line,) = ax.plot(
                        rho,
                        data[0, :],
                        plot_config.colors[color_idx % len(plot_config.colors)],
                        label=label,
                    )
                    lines.append(line)
                else:
                    lines[line_idx].set_xdata(rho)
                    lines[line_idx].set_ydata(data[0, :])
                line_idx += 1
        elif cfg.plot_type == plotruns_lib.PlotType.TIME_SERIES:
            for color_idx, (attr, label) in enumerate(zip(cfg.attrs, cfg.labels)):
                data = getattr(plotdata, attr)
                if first_update:
                    # Plot the entire time series, as plotruns_lib.get_lines() does.
                    (line,) = ax.plot(
                        plotdata.t,
                        data,
                        plot_config.colors[color_idx % len(plot_config.colors)],
                        label=label,
                    )
                    lines.append(line)
                else:
                    # Extend the existing curve with the new point(s).
                    # NOTE(review): presumably ``data`` holds a single value
                    # here, so x and y stay the same length -- confirm.
                    line = lines[line_idx]
                    xdata, ydata = line.get_xdata(), line.get_ydata()
                    line.set_xdata(np.append(xdata, t))
                    line.set_ydata(np.append(ydata, data))
                line_idx += 1
        else:
            raise ValueError(f"Unknown plot type: {cfg.plot_type}")
    return lines
[docs]
def validate_plotdata(
    plotdata: plotruns_lib.PlotData, plot_config: plotruns_lib.FigureProperties
):
    """Verify that ``plotdata`` provides every attribute ``plot_config`` plots.

    An attribute counts as available when it is either a dataclass field of
    ``plotdata`` or a ``property`` defined on its class (found through
    ``inspect.getmembers``). This is the same check that ``plot_run()``
    performs.

    Args:
        plotdata (plotruns_lib.PlotData): Data object to check.
        plot_config (plotruns_lib.FigureProperties): Plot configuration whose
            ``axes`` entries each carry an ``attrs`` list of variable names.

    Raises:
        ValueError: For the first requested attribute that is unavailable
            (scanning axes in order, then each axis' ``attrs`` in order). The
            message names that attribute.
    """
    available = set(plotdata.__dataclass_fields__)
    available.update(
        member_name
        for member_name, _ in inspect.getmembers(
            type(plotdata), lambda member: isinstance(member, property)
        )
    )

    requested = (attr for axis_cfg in plot_config.axes for attr in axis_cfg.attrs)
    for attr in requested:
        if attr not in available:
            raise ValueError(
                f"Attribute '{attr}' in plot_config does not exist in PlotData"
            )
[docs]
def load_data(data_tree: xr.DataTree) -> plotruns_lib.PlotData:
    r"""Convert TORAX DataTree output to PlotData with unit transformations.

    Extracts the time coordinate and applies unit conversions to match TORAX
    plotting conventions (A/m² → MA/m², W → MW, m⁻³ → 10²⁰ m⁻³, etc.). Reads
    variables from the ``profiles`` and ``scalars`` child nodes of the
    `DataTree`, and radial grid coordinates from its root node.

    The unit conversion writes its results into newly allocated arrays on
    per-node copies. ``data_tree`` itself is therefore left unmodified, and
    this function can safely be called repeatedly on the same output.

    Optional profile variables that are missing from the output are replaced
    by zero arrays of shape ``(len(time), n_rho)`` on the cell or face grid.

    Args:
        data_tree (xarray.DataTree): TORAX simulation output.

    Returns:
        plotruns_lib.PlotData: Object with plasma variables in plotting units.

    Raises:
        ValueError: If an optional variable is requested with a ``grid_type``
            other than ``"cell"`` or ``"face"`` (case-insensitive).
    """
    time = data_tree[output.TIME].to_numpy()

    def get_optional_data(ds, key, grid_type):
        """Return ``ds[key]`` as an array, or zeros on the given grid if absent."""
        # Normalize once so that validation and grid selection agree.
        grid = grid_type.lower()
        if grid not in ("cell", "face"):
            raise ValueError(
                f'grid_type for {key} must be either "cell" or "face", got {grid_type}'
            )
        if key in ds:
            return ds[key].to_numpy()
        rho_key = output.RHO_CELL_NORM if grid == "cell" else output.RHO_FACE_NORM
        return np.zeros((len(time), len(ds[rho_key])))

    # Divisors that convert SI output units to plotting units. Variables that
    # are absent from a given node are skipped.
    unit_divisors = {
        output.J_TOTAL: 1e6,  # A/m^2 to MA/m^2
        output.J_OHMIC: 1e6,  # A/m^2 to MA/m^2
        output.J_BOOTSTRAP: 1e6,  # A/m^2 to MA/m^2
        output.J_EXTERNAL: 1e6,  # A/m^2 to MA/m^2
        "j_generic_current": 1e6,  # A/m^2 to MA/m^2
        output.I_BOOTSTRAP: 1e6,  # A to MA
        output.IP_PROFILE: 1e6,  # A to MA
        "j_ecrh": 1e6,  # A/m^2 to MA/m^2
        "p_icrh_i": 1e6,  # W/m^3 to MW/m^3
        "p_icrh_e": 1e6,  # W/m^3 to MW/m^3
        "p_generic_heat_i": 1e6,  # W/m^3 to MW/m^3
        "p_generic_heat_e": 1e6,  # W/m^3 to MW/m^3
        "p_ecrh_e": 1e6,  # W/m^3 to MW/m^3
        "p_alpha_i": 1e6,  # W/m^3 to MW/m^3
        "p_alpha_e": 1e6,  # W/m^3 to MW/m^3
        "p_ohmic_e": 1e6,  # W/m^3 to MW/m^3
        "p_bremsstrahlung_e": 1e6,  # W/m^3 to MW/m^3
        "p_cyclotron_radiation_e": 1e6,  # W/m^3 to MW/m^3
        "p_impurity_radiation_e": 1e6,  # W/m^3 to MW/m^3
        "ei_exchange": 1e6,  # W/m^3 to MW/m^3
        "P_ohmic_e": 1e6,  # W to MW
        "P_aux_total": 1e6,  # W to MW
        "P_alpha_total": 1e6,  # W to MW
        "P_bremsstrahlung_e": 1e6,  # W to MW
        "P_cyclotron_e": 1e6,  # W to MW
        "P_ecrh": 1e6,  # W to MW
        "P_radiation_e": 1e6,  # W to MW
        "I_ecrh": 1e6,  # A to MA
        "I_aux_generic": 1e6,  # A to MA
        "W_thermal_total": 1e6,  # J to MJ
        output.N_E: 1e20,  # m^-3 to 10^{20} m^-3
        output.N_I: 1e20,  # m^-3 to 10^{20} m^-3
        output.N_IMPURITY: 1e20,  # m^-3 to 10^{20} m^-3
    }

    def _transform_data(ds: xr.Dataset):
        """Return a copy of ``ds`` with known variables in plotting units."""
        # TODO(b/414755419)
        ds = ds.copy()
        for var_name, scale in unit_divisors.items():
            if var_name in ds:
                # Out-of-place division: ``ds.copy()`` is shallow, so an
                # in-place ``/=`` would also rescale the caller's arrays.
                ds[var_name] = ds[var_name] / scale
        return ds

    converted = xr.map_over_datasets(_transform_data, data_tree)
    profiles_dataset = converted.children[output.PROFILES].dataset
    scalars_dataset = converted.children[output.SCALARS].dataset
    dataset = converted.dataset

    return plotruns_lib.PlotData(
        T_i=profiles_dataset[output.T_I].to_numpy(),
        T_e=profiles_dataset[output.T_E].to_numpy(),
        n_e=profiles_dataset[output.N_E].to_numpy(),
        n_i=profiles_dataset[output.N_I].to_numpy(),
        n_impurity=profiles_dataset[output.N_IMPURITY].to_numpy(),
        Z_impurity=profiles_dataset[output.Z_IMPURITY].to_numpy(),
        psi=profiles_dataset[output.PSI].to_numpy(),
        v_loop=profiles_dataset[output.V_LOOP].to_numpy(),
        j_total=profiles_dataset[output.J_TOTAL].to_numpy(),
        j_ohmic=profiles_dataset[output.J_OHMIC].to_numpy(),
        j_bootstrap=profiles_dataset[output.J_BOOTSTRAP].to_numpy(),
        j_external=profiles_dataset[output.J_EXTERNAL].to_numpy(),
        j_ecrh=get_optional_data(profiles_dataset, "j_ecrh", "cell"),
        j_generic_current=get_optional_data(
            profiles_dataset, "j_generic_current", "cell"
        ),
        q=profiles_dataset[output.Q].to_numpy(),
        magnetic_shear=profiles_dataset[output.MAGNETIC_SHEAR].to_numpy(),
        chi_turb_i=profiles_dataset[output.CHI_TURB_I].to_numpy(),
        chi_turb_e=profiles_dataset[output.CHI_TURB_E].to_numpy(),
        D_turb_e=profiles_dataset[output.D_TURB_E].to_numpy(),
        V_turb_e=profiles_dataset[output.V_TURB_E].to_numpy(),
        rho_norm=dataset[output.RHO_NORM].to_numpy(),
        rho_cell_norm=dataset[output.RHO_CELL_NORM].to_numpy(),
        rho_face_norm=dataset[output.RHO_FACE_NORM].to_numpy(),
        p_icrh_i=get_optional_data(profiles_dataset, "p_icrh_i", "cell"),
        p_icrh_e=get_optional_data(profiles_dataset, "p_icrh_e", "cell"),
        p_generic_heat_i=get_optional_data(
            profiles_dataset, "p_generic_heat_i", "cell"
        ),
        p_generic_heat_e=get_optional_data(
            profiles_dataset, "p_generic_heat_e", "cell"
        ),
        p_ecrh_e=get_optional_data(profiles_dataset, "p_ecrh_e", "cell"),
        p_alpha_i=get_optional_data(profiles_dataset, "p_alpha_i", "cell"),
        p_alpha_e=get_optional_data(profiles_dataset, "p_alpha_e", "cell"),
        p_ohmic_e=get_optional_data(profiles_dataset, "p_ohmic_e", "cell"),
        p_bremsstrahlung_e=get_optional_data(
            profiles_dataset, "p_bremsstrahlung_e", "cell"
        ),
        p_cyclotron_radiation_e=get_optional_data(
            profiles_dataset, "p_cyclotron_radiation_e", "cell"
        ),
        p_impurity_radiation_e=get_optional_data(
            profiles_dataset, "p_impurity_radiation_e", "cell"
        ),
        ei_exchange=profiles_dataset["ei_exchange"].to_numpy(),  # ion heating/sink
        Q_fusion=scalars_dataset["Q_fusion"].to_numpy(),  # pylint: disable=invalid-name
        s_gas_puff=get_optional_data(profiles_dataset, "s_gas_puff", "cell"),
        s_generic_particle=get_optional_data(
            profiles_dataset, "s_generic_particle", "cell"
        ),
        s_pellet=get_optional_data(profiles_dataset, "s_pellet", "cell"),
        # Last grid point of the plasma current profile.
        Ip_profile=profiles_dataset[output.IP_PROFILE].to_numpy()[:, -1],
        I_bootstrap=scalars_dataset[output.I_BOOTSTRAP].to_numpy(),
        I_aux_generic=scalars_dataset["I_aux_generic"].to_numpy(),
        I_ecrh=scalars_dataset["I_ecrh"].to_numpy(),
        P_ohmic_e=scalars_dataset["P_ohmic_e"].to_numpy(),
        P_auxiliary=scalars_dataset["P_aux_total"].to_numpy(),
        P_alpha_total=scalars_dataset["P_alpha_total"].to_numpy(),
        # Total sink power: bremsstrahlung + radiation + cyclotron.
        P_sink=scalars_dataset["P_bremsstrahlung_e"].to_numpy()
        + scalars_dataset["P_radiation_e"].to_numpy()
        + scalars_dataset["P_cyclotron_e"].to_numpy(),
        P_bremsstrahlung_e=scalars_dataset["P_bremsstrahlung_e"].to_numpy(),
        P_radiation_e=scalars_dataset["P_radiation_e"].to_numpy(),
        P_cyclotron_e=scalars_dataset["P_cyclotron_e"].to_numpy(),
        T_e_volume_avg=scalars_dataset["T_e_volume_avg"].to_numpy(),
        T_i_volume_avg=scalars_dataset["T_i_volume_avg"].to_numpy(),
        n_e_volume_avg=scalars_dataset["n_e_volume_avg"].to_numpy(),
        n_i_volume_avg=scalars_dataset["n_i_volume_avg"].to_numpy(),
        W_thermal_total=scalars_dataset["W_thermal_total"].to_numpy(),
        q95=scalars_dataset["q95"].to_numpy(),
        t=time,
    )