# Source code for infralib.visualize.simulator_plots

"""Matplotlib/Seaborn visualization utilities for infrastructure simulators."""

from typing import Any

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MaxNLocator


def set_plot_style(
    style: str = "whitegrid",
    context: str = "paper",
    font_scale: float = 1.5,
):
    """Apply the shared publication-quality look to subsequent figures.

    Selects a seaborn style and context, then overrides a few matplotlib
    rcParams: figure DPI 100, saved-figure DPI 300, and no top/right spines.

    Parameters
    ----------
    style : str
        Seaborn style preset handed to ``sns.set_style``.
    context : str
        Seaborn plotting context handed to ``sns.set_context``.
    font_scale : float
        Multiplier applied to the context's font sizes.
    """
    sns.set_style(style)
    sns.set_context(context, font_scale=font_scale)

    # Applied after the seaborn presets so these values take precedence.
    overrides = {
        "figure.dpi": 100,
        "savefig.dpi": 300,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    plt.rcParams.update(overrides)
def plot_state_budget_history(
    simulator: Any,
    num_steps: int,
    save_path: str | None = None,
    show: bool = True,
    figsize: tuple[int, int] = (14, 15),
) -> plt.Figure:
    """
    Plot the distribution of component states over time, budget changes,
    and action counts.

    The figure has three vertically stacked panels sharing the time axis:

    - Distribution of component states per time step (violin plot)
    - Remaining budget per recorded budget entry
    - Number of components receiving each action type per time step

    Parameters
    ----------
    simulator : Simulator
        The simulator instance after simulation. ``simulator.history`` must
        provide ``"states"`` (one row of component states per time step),
        ``"budget_remaining"`` (a sequence of budget values) and
        ``"actions"`` (one row of per-component action codes per
        transition, i.e. the first state has no action). If the action and
        state histories end up with different lengths, only the time steps
        present in both are plotted.
    num_steps : int
        Number of time steps in the simulation. Not used by the current
        implementation (the time axis is derived from ``history["states"]``);
        kept for API compatibility.
    save_path : str, optional
        Path to save the figure
    show : bool
        Whether to display the plot
    figsize : tuple
        Figure size (width, height)

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    sns.set_theme(style="whitegrid")
    plt.rcParams.update({"figure.dpi": 100})

    # Extract data from simulator
    state_history_array = np.array(simulator.history["states"])
    time_steps = np.arange(len(state_history_array))
    num_components = state_history_array.shape[1]

    # Create figure with three subplots sharing the time axis
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=figsize, sharex=True)

    # --- Top plot: State distribution ---
    # Long-form table: one row per (time step, component) pair, matching the
    # row-major order produced by ``flatten``.
    df = pd.DataFrame(
        {
            "Time Step": np.repeat(time_steps, num_components),
            "State": state_history_array.flatten(),
        }
    )
    sns.violinplot(
        x="Time Step",
        y="State",
        data=df,
        inner="quartile",
        ax=ax1,
        density_norm="width",
        cut=0,
    )
    ax1.set_title("Distribution of Component States Over Time", fontsize=16, pad=20)
    ax1.set_ylabel("State", fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

    # --- Middle plot: Budget changes ---
    budget_changes = np.array(simulator.history["budget_remaining"])
    # Index the budget by its own length so a budget history whose length
    # differs from the state history does not raise a shape mismatch.
    budget_steps = np.arange(len(budget_changes))
    ax2.plot(
        budget_steps,
        budget_changes,
        marker="o",
        color="darkred",
        alpha=0.7,
        linewidth=2,
        markersize=5,
    )
    ax2.fill_between(budget_steps, budget_changes, alpha=0.2, color="red")
    ax2.set_title("Budget Changes Over Time", fontsize=16, pad=20)
    ax2.set_ylabel("Budget", fontsize=14)
    ax2.grid(True, alpha=0.3)
    ax2.yaxis.set_major_formatter(mtick.StrMethodFormatter("{x:,.0f}"))
    ax2.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

    # --- Bottom plot: Action counts ---
    # Sentinel for "no action recorded". It matches none of the action codes
    # 0-3, so padding rows are not miscounted as "Do Nothing" (code 0).
    no_action = -1
    raw_actions = simulator.history["actions"]
    if len(raw_actions) > 0:
        action_history = np.array(raw_actions)
        # The first state precedes any action: prepend one padding row so
        # action row i lines up with state row i.
        padding = np.full((1, *action_history.shape[1:]), no_action)
        action_history = np.vstack([padding, action_history])
    else:
        # No actions taken yet: every time step is padding.
        action_history = np.full((len(time_steps), num_components), no_action)

    action_names = ["Do Nothing", "Inspect", "Repair", "Replace"]
    colors = ["blue", "green", "orange", "red"]
    markers = ["o", "s", "^", "D"]

    # Plot only time steps present in both histories, so x and y always
    # have the same length regardless of which history is longer.
    n_points = min(len(time_steps), len(action_history))
    plot_time_steps = time_steps[:n_points]
    aligned_actions = action_history[:n_points]

    legend_handles = []
    legend_labels = []
    for action_type, (name, color, marker) in enumerate(
        zip(action_names, colors, markers)
    ):
        # Number of components that received this action at each time step.
        action_counts = np.sum(aligned_actions == action_type, axis=1)
        if np.any(action_counts > 0):
            scatter = ax3.scatter(
                plot_time_steps,
                action_counts,
                color=color,
                marker=marker,
                alpha=0.7,
                s=50,
            )
            legend_handles.append(scatter)
            legend_labels.append(name)

    ax3.set_title("Action Counts Over Time", fontsize=16, pad=20)
    ax3.set_xlabel("Time Step", fontsize=14)
    ax3.set_ylabel("Number of Actions", fontsize=14)
    ax3.grid(True, alpha=0.3)

    # Horizontal legend placed above the bottom panel
    if legend_handles:
        ax3.legend(
            legend_handles,
            legend_labels,
            bbox_to_anchor=(0.5, 1.25),
            loc="center",
            ncol=len(legend_handles),
            fontsize=12,
        )

    ax3.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=10))

    # Operate on this figure explicitly rather than the implicit "current"
    # figure, which other code may have changed.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    return fig
def plot_component_states_comparison(
    state_histories: list[np.ndarray],
    action_histories: list[np.ndarray] | None = None,
    labels: list[str] | None = None,
    failure_conditions: np.ndarray | None = None,
    component_idx: int = 0,
    type_indices: np.ndarray | None = None,
    max_steps: int | None = None,
    save_path: str | None = None,
    show: bool = True,
    figsize: tuple[int, int] | None = None,
) -> plt.Figure:
    """
    Plot and compare condition index over time for a specific component
    across different policies.

    This is useful for comparing how different control strategies affect
    individual component degradation and maintenance. Each policy gets its
    own panel. The time axis spans the first history (optionally capped by
    ``max_steps``). Shorter histories are plotted over as many steps as
    they contain.

    Parameters
    ----------
    state_histories : list of np.ndarray
        State history for each policy/strategy. Shape: (timesteps, n_components)
    action_histories : list of np.ndarray, optional
        Action history for each policy. Shape: (timesteps, n_components).
        Rows beyond the length of the matching state slice are ignored.
    labels : list of str, optional
        Labels for each policy/strategy; defaults to "Method 1", "Method 2", ...
    failure_conditions : np.ndarray, optional
        Failure thresholds for each component type. Without ``type_indices``
        the first entry is used (or 2.0 if the array is empty).
    component_idx : int
        Index of component to visualize
    type_indices : np.ndarray, optional
        Mapping from component index to component type
    max_steps : int, optional
        Maximum number of steps to show; a falsy value shows all steps.
    save_path : str, optional
        Path to save the figure
    show : bool
        Whether to display the plot
    figsize : tuple, optional
        Figure size (width, height); defaults to (12, 4 * number of policies).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    set_plot_style()

    num_methods = len(state_histories)
    time_steps = np.arange(len(state_histories[0]))
    if max_steps:
        time_steps = time_steps[:max_steps]
    n_steps = len(time_steps)

    if labels is None:
        labels = [f"Method {i+1}" for i in range(num_methods)]

    if figsize is None:
        figsize = (12, 4 * num_methods)

    fig, axes = plt.subplots(num_methods, 1, figsize=figsize, sharex=True)
    if num_methods == 1:
        # plt.subplots returns a bare Axes (not an array) for a single row.
        axes = [axes]

    # Color scheme for actions
    uiuc_orange = (255 / 255.0, 85 / 255.0, 46 / 255.0)
    action_colors = {0: "gray", 1: "green", 2: uiuc_orange, 3: "blue"}
    action_labels = {0: "No Action", 1: "Inspect", 2: "Repair", 3: "Replace"}

    # The failure threshold is the same for every panel: resolve it once.
    failure_threshold = None
    if failure_conditions is not None:
        if type_indices is not None:
            failure_threshold = failure_conditions[type_indices[component_idx]]
        else:
            # Single failure threshold for all components
            failure_threshold = (
                failure_conditions[0] if len(failure_conditions) > 0 else 2.0
            )

    for idx, ax in enumerate(axes):
        # Extract component history
        state_history = state_histories[idx][:n_steps, component_idx]
        # Policies may produce histories of different lengths; plot each one
        # against its own x range so x and y always match.
        steps = time_steps[: len(state_history)]

        # Plot condition index
        ax.plot(
            steps,
            state_history,
            ".-",
            color="black",
            linewidth=2,
            alpha=1.0,
            label="Condition Index",
        )

        # Overlay actions if provided
        if action_histories is not None:
            # Clip to the state slice so every action index has a state value.
            action_history = action_histories[idx][: len(state_history), component_idx]
            for action in np.unique(action_history):
                # Every unique value occurs at least once, so indices is non-empty.
                indices = np.flatnonzero(action_history == action)
                ax.scatter(
                    indices,
                    state_history[indices],
                    color=action_colors.get(action, "black"),
                    s=70,
                    alpha=0.9,
                    label=action_labels.get(action, f"Action {action}"),
                )

        # Plot failure threshold if provided
        if failure_threshold is not None:
            ax.axhline(
                failure_threshold,
                color="red",
                linestyle="--",
                linewidth=1.5,
                alpha=0.6,
                label="Failure Threshold",
            )

        ax.set_xlim(0, n_steps - 1)
        ax.set_ylim(0, np.max(state_history) + 10)
        ax.set_title(labels[idx], fontsize=14)
        ax.set_ylabel("Condition Index", fontsize=12)
        ax.grid(True, alpha=0.3)

        if idx == num_methods - 1:
            ax.set_xlabel("Time Steps", fontsize=12)

        # Remove duplicate labels in legend (last handle per label wins)
        handles, labels_legend = ax.get_legend_handles_labels()
        by_label = dict(zip(labels_legend, handles, strict=False))
        ax.legend(by_label.values(), by_label.keys(), loc="best", fontsize=10)

    # Operate on this figure explicitly rather than the implicit "current" one.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    return fig
def plot_action_distribution(
    action_history: np.ndarray,
    save_path: str | None = None,
    show: bool = True,
    figsize: tuple[int, int] = (10, 6),
) -> plt.Figure:
    """
    Plot the distribution of actions taken during simulation.

    The left panel is a pie chart of total action counts. The right panel
    shows, per time step, how many components received each action type.
    Only action codes 0-3 ("Do Nothing", "Inspect", "Repair", "Replace")
    are counted; any other values (e.g. padding sentinels) are ignored.

    Parameters
    ----------
    action_history : np.ndarray
        Action history array. Shape: (timesteps, n_components). A 1-D
        array is treated as the history of a single component. Any
        array-like accepted by ``np.asarray`` works, including float dtypes.
    save_path : str, optional
        Path to save the figure
    show : bool
        Whether to display the plot
    figsize : tuple
        Figure size (width, height)

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    set_plot_style()

    action_history = np.asarray(action_history)
    if action_history.ndim == 1:
        # Single component: make it (timesteps, 1) so per-step sums work.
        action_history = action_history[:, np.newaxis]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    action_names = ["Do Nothing", "Inspect", "Repair", "Replace"]
    colors = ["blue", "green", "orange", "red"]

    # Count totals by equality instead of np.bincount. bincount rejects float
    # or negative input, and emits extra bins for codes > 3 that would not
    # line up with the four labels below.
    action_counts = np.array(
        [np.count_nonzero(action_history == code) for code in range(len(action_names))]
    )

    # Pie chart of the action types that actually occur
    present = action_counts > 0
    if np.any(present):
        ax1.pie(
            action_counts[present],
            labels=[name for name, keep in zip(action_names, present) if keep],
            colors=[color for color, keep in zip(colors, present) if keep],
            autopct="%1.1f%%",
            startangle=90,
        )
    else:
        # A pie of all-zero counts is undefined; show a placeholder instead.
        ax1.text(
            0.5,
            0.5,
            "No actions recorded",
            ha="center",
            va="center",
            transform=ax1.transAxes,
        )
        ax1.set_axis_off()
    ax1.set_title("Overall Action Distribution")

    # Actions over time
    time_steps = np.arange(action_history.shape[0])
    plotted_any = False
    for action_type, (name, color) in enumerate(zip(action_names, colors)):
        # Number of components that received this action at each time step.
        action_counts_time = np.sum(action_history == action_type, axis=1)
        if np.any(action_counts_time > 0):
            ax2.plot(
                time_steps,
                action_counts_time,
                label=name,
                color=color,
                alpha=0.7,
                linewidth=2,
            )
            plotted_any = True

    ax2.set_title("Action Counts Over Time")
    ax2.set_xlabel("Time Step")
    ax2.set_ylabel("Number of Actions")
    if plotted_any:
        # Avoid matplotlib's "no artists with labels" warning on empty plots.
        ax2.legend(loc="best")
    ax2.grid(True, alpha=0.3)

    # Operate on this figure explicitly rather than the implicit "current" one.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    return fig