# Source code for infralib.envs.base

r"""Base infrastructure management environment for gymnasium and stable-baselines3.

This module provides a base class for infrastructure maintenance environments that:

- Is fully compatible with gymnasium and stable-baselines3
- Uses the updated Simulator with model dependency system
- Provides abstract methods for model creation (no models defined in base)
- Supports all observability modes and reward schemes
- Includes comprehensive action and observation space handling

The base class assumes all inheriting environments will define appropriate
models (dynamics, cost, budget, hierarchy, metadata) as needed.

Example
-------
Creating a custom environment by inheriting from BaseInfraEnv::

    class MyCustomEnv(BaseInfraEnv):
        def _create_models(self):
            dynamics = MyDynamicsModel()
            cost = MyCostModel()
            budget = MyBudgetModel()
            return dynamics, cost, budget, None, None

        def _compute_reward(self, sim_info):
            return -sim_info['total_cost']

Classes
-------
BaseInfraEnv : Abstract base class for infrastructure environments
"""

from abc import ABC, abstractmethod
from typing import Any

import gymnasium as gym
import numpy as np

from ..models.budget import BudgetModel
from ..models.cost import CostModel
from ..models.dynamics import DynamicsModel
from ..models.hierarchy import HierarchyModel
from ..models.metadata import MetadataModel
from ..simulator import Simulator


class BaseInfraEnv(gym.Env, ABC):
    """Abstract base class for infrastructure maintenance environments.

    This base class provides the core functionality for infrastructure
    maintenance RL environments while requiring subclasses to define their
    own models and reward functions. It follows the gymnasium ``Env`` API
    (``reset``/``step``/``render``/``close``), which is what
    stable-baselines3 consumes.

    The environment handles:

    - Action and observation space definition
    - Simulator integration with model dependencies
    - Episode management (reset, step, termination)
    - Multiple observability modes (full, partial, noisy)
    - Rendering capabilities

    Subclasses must implement:

    - ``_create_models()``: define dynamics, cost, budget, and optional
      hierarchy/metadata models
    - ``_compute_reward()``: define the reward function based on
      simulation info

    Parameters
    ----------
    n_components : int
        Number of infrastructure components to simulate
    max_steps : int, default 365
        Maximum number of steps per episode
    observability : {'full', 'partial', 'noisy'}, default 'full'
        Type of state observability
    action_type : {'multi_discrete', 'discrete', 'box'}, default 'multi_discrete'
        Format of action space
    render_mode : str, optional
        Rendering mode ('human', 'rgb_array', None)
    rich_display : bool, default False
        Enable rich terminal displays during simulation
    seed : int, optional
        Random seed for reproducibility (forwarded to the simulator)

    Attributes
    ----------
    n_components : int
        Number of components in the system
    simulator : Simulator
        Infrastructure simulator instance
    current_step : int
        Current episode step counter
    action_space : gym.Space
        Gymnasium action space
    observation_space : gym.Space
        Gymnasium observation space

    Notes
    -----
    This class is designed to work with stable-baselines3 and other
    gymnasium-compatible RL libraries; environments inheriting from it are
    intended to pass gymnasium's ``env_checker``.

    Every component has 4 maintenance actions, numbered 0-3. The action
    space supports multiple formats:

    - ``'multi_discrete'``: separate action per component, ``[4, 4, ..., 4]``
    - ``'discrete'``: a single integer encoding all components
      (``4 ** n_components`` possibilities; component 0 is the most
      significant base-4 digit)
    - ``'box'``: continuous actions in ``[0, 3]``, rounded with
      ``np.round`` (round-half-to-even) to integers (for advanced use cases)

    Examples
    --------
    >>> class SimpleEnv(BaseInfraEnv):
    ...     def _create_models(self):
    ...         dynamics = MarkovDynamics(n_states=10)
    ...         cost = SimpleCost()
    ...         budget = FixedBudget(initial_budget=5000)
    ...         return dynamics, cost, budget, None, None
    ...
    ...     def _compute_reward(self, sim_info):
    ...         return -sim_info['total_cost'] - sim_info['failures'] * 100
    ...
    >>> env = SimpleEnv(n_components=5)
    >>> obs, info = env.reset()
    >>> action = env.action_space.sample()
    >>> obs, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    # Number of discrete maintenance actions per component; a per-component
    # action is an integer in [0, _N_ACTIONS - 1].
    _N_ACTIONS = 4

    def __init__(
        self,
        n_components: int,
        max_steps: int = 365,
        observability: str = "full",
        action_type: str = "multi_discrete",
        render_mode: str | None = None,
        rich_display: bool = False,
        seed: int | None = None,
    ):
        """Store the configuration, build the simulator and define the spaces.

        See the class docstring for the meaning of each parameter.

        Raises
        ------
        ValueError
            If ``action_type`` is not one of the supported formats.
        """
        super().__init__()

        # Environment configuration
        self.n_components = n_components
        self.max_steps = max_steps
        self.observability = observability
        self.action_type = action_type
        self.render_mode = render_mode
        self.rich_display = rich_display
        self.seed = seed

        # Episode state
        self.current_step = 0
        self.terminated = False
        self.truncated = False

        # Models come from the subclass. The local is named ``metadata_model``
        # so it is not confused with gymnasium's ``metadata`` class attribute.
        dynamics, cost, budget, hierarchy, metadata_model = self._create_models()

        self.simulator = Simulator(
            dynamics=dynamics,
            cost=cost,
            budget=budget,
            hierarchy=hierarchy,
            metadata=metadata_model,
            rich_display=rich_display,
            seed=seed,
        )

        # Subclasses may define ``failure_thresholds`` (e.g. in
        # ``_create_models``); mirror it onto the simulator for visualization.
        if hasattr(self, "failure_thresholds"):
            self.simulator.failure_conditions = self.failure_thresholds

        self._define_spaces()

        # Per-step records appended by ``step`` when a render mode is set.
        self.render_history: list[dict[str, Any]] = []

    @abstractmethod
    def _create_models(
        self,
    ) -> tuple[
        DynamicsModel,
        CostModel,
        BudgetModel,
        HierarchyModel | None,
        MetadataModel | None,
    ]:
        """Create and return models for the environment.

        This method must be implemented by subclasses to define the specific
        models used by the environment. The base class makes no assumptions
        about which models to use.

        Returns
        -------
        tuple
            (dynamics, cost, budget, hierarchy, metadata) where hierarchy and
            metadata can be None if not needed

        Examples
        --------
        >>> def _create_models(self):
        ...     dynamics = MarkovDynamics(n_states=10)
        ...     cost = SimpleCost(base_repair_cost=100)
        ...     budget = FixedBudget(initial_budget=5000)
        ...     return dynamics, cost, budget, None, None
        """

    @abstractmethod
    def _compute_reward(self, sim_info: dict[str, Any]) -> float:
        """Compute reward based on simulation step information.

        This method must be implemented by subclasses to define the reward
        function. It receives the info dictionary from the simulator step.

        Parameters
        ----------
        sim_info : dict
            Information dictionary from simulator.step() containing:

            - 'total_cost': Cost of actions taken
            - 'failures': Number of failed components
            - 'budget_remaining': Remaining budget
            - 'mean_condition': Average component condition
            - And other simulation metrics

        Returns
        -------
        float
            Scalar reward value

        Examples
        --------
        >>> def _compute_reward(self, sim_info):
        ...     cost_penalty = sim_info['total_cost'] / 1000.0
        ...     failure_penalty = sim_info['failures'] * 10.0
        ...     condition_reward = sim_info['mean_condition'] / 10.0
        ...     return condition_reward - cost_penalty - failure_penalty
        """

    def _define_spaces(self):
        """Define gymnasium action and observation spaces.

        The observation dimension is discovered by resetting the simulator to
        a temporary uniform state and sampling one observation from it.

        Raises
        ------
        ValueError
            If ``action_type`` is unknown.
        """
        n_actions = self._N_ACTIONS

        # Action space
        if self.action_type == "multi_discrete":
            # Separate discrete action for each component: [0, 1, 2, 3] each
            self.action_space = gym.spaces.MultiDiscrete(
                [n_actions] * self.n_components
            )
        elif self.action_type == "discrete":
            # Single discrete action encoding all components
            self.action_space = gym.spaces.Discrete(n_actions**self.n_components)
        elif self.action_type == "box":
            # Continuous actions; rounded to discrete ones in _process_action
            self.action_space = gym.spaces.Box(
                low=0, high=n_actions - 1, shape=(self.n_components,), dtype=np.float32
            )
        else:
            raise ValueError(f"Unknown action_type: {self.action_type}")

        # Observation space: build a throwaway observation to learn its size.
        temp_states = np.ones(self.n_components) * 5  # mid-range states
        self.simulator.reset(self.n_components, temp_states)
        sample_obs = self.simulator.get_observation(self.observability)
        obs_dim = len(sample_obs)

        if self.observability in ["full", "partial", "noisy"]:
            # NOTE(review): these bounds assume the simulator normalises every
            # observation entry (states, time, budget) into [-1, 1]; that is
            # not verifiable from this file. Noisy observations in particular
            # should be checked against these bounds.
            low = np.full(obs_dim, -1.0, dtype=np.float32)
            high = np.full(obs_dim, 1.0, dtype=np.float32)
        else:
            low = np.full(obs_dim, -np.inf, dtype=np.float32)
            high = np.full(obs_dim, np.inf, dtype=np.float32)

        self.observation_space = gym.spaces.Box(
            low=low, high=high, shape=(obs_dim,), dtype=np.float32
        )

    def reset(
        self, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Reset the environment to start a new episode.

        Parameters
        ----------
        seed : int, optional
            Random seed for the episode
        options : dict, optional
            Additional options including 'initial_states'

        Returns
        -------
        tuple
            (observation, info) where observation is the initial state
            observation (float32) and info contains environment metadata
        """
        super().reset(seed=seed)
        if seed is not None:
            # Also seeds NumPy's *global* RNG. NOTE(review): presumably the
            # simulator/models draw from it; confirm, and prefer a dedicated
            # Generator if possible since this affects unrelated code too.
            np.random.seed(seed)

        initial_states = None
        if options and "initial_states" in options:
            initial_states = options["initial_states"]
        self.simulator.reset(self.n_components, initial_states)

        # Reset episode state
        self.current_step = 0
        self.terminated = False
        self.truncated = False
        self.render_history = []

        observation = self._get_observation()
        info = self._get_info()
        return observation.astype(np.float32), info

    def step(
        self, action: int | np.ndarray
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Take a step in the environment.

        Parameters
        ----------
        action : int or np.ndarray
            Action to take, format depends on action_type

        Returns
        -------
        tuple
            (observation, reward, terminated, truncated, info); ``info``
            contains the environment info updated with the simulator info
            (simulator keys win on collision)

        Raises
        ------
        RuntimeError
            If called on terminated/truncated environment
        ValueError
            If action format is invalid
        """
        if self.terminated or self.truncated:
            raise RuntimeError("Cannot call step() on terminated/truncated environment")

        action_array = self._process_action(action)
        states, sim_info = self.simulator.step(action_array)
        reward = self._compute_reward(sim_info)

        # Termination is evaluated before the counter is incremented;
        # _check_termination accounts for that (and subclasses may rely on it).
        self.terminated, self.truncated = self._check_termination(sim_info)
        self.current_step += 1

        observation = self._get_observation()
        info = self._get_info()
        info.update(sim_info)

        if self.render_mode is not None:
            self.render_history.append(
                {
                    "step": self.current_step,
                    "states": states.copy(),
                    "actions": action_array.copy(),
                    "reward": reward,
                    "info": sim_info.copy(),
                }
            )

        return (
            observation.astype(np.float32),
            float(reward),
            self.terminated,
            self.truncated,
            info,
        )

    def _process_action(self, action: int | np.ndarray) -> np.ndarray:
        """Convert an action in the configured format to per-component actions.

        Parameters
        ----------
        action : int or np.ndarray
            Action in the format given by ``action_type``.

        Returns
        -------
        np.ndarray
            int32 array of shape ``(n_components,)`` with values in
            ``[0, _N_ACTIONS - 1]``.

        Raises
        ------
        ValueError
            If the action has the wrong shape, is out of range, or
            ``action_type`` is unknown.
        """
        n_actions = self._N_ACTIONS

        if self.action_type == "multi_discrete":
            action_array = np.asarray(action, dtype=np.int32)
            # Check ndim before len(): len() of a 0-d array raises TypeError,
            # and 2-D inputs must not slip through on a matching first axis.
            if action_array.ndim != 1:
                raise ValueError(
                    f"Multi-discrete action must be 1-D, got shape {action_array.shape}"
                )
            if len(action_array) != self.n_components:
                raise ValueError(
                    f"Multi-discrete action length {len(action_array)} != n_components {self.n_components}"
                )
            if np.any((action_array < 0) | (action_array >= n_actions)):
                raise ValueError(
                    f"Multi-discrete actions must lie in [0, {n_actions - 1}], "
                    f"got {action_array.tolist()}"
                )

        elif self.action_type == "discrete":
            # .item() accepts Python ints, NumPy scalars and size-1 arrays
            # (int() on a 1-element ndarray is deprecated in recent NumPy).
            action_idx = int(np.asarray(action).item())
            n_joint = n_actions**self.n_components
            if not 0 <= action_idx < n_joint:
                raise ValueError(
                    f"Discrete action {action_idx} out of range [0, {n_joint - 1}]"
                )
            # Decode base-n_actions digits least-significant first, then
            # reverse so component 0 is the most significant digit.
            digits = []
            for _ in range(self.n_components):
                action_idx, digit = divmod(action_idx, n_actions)
                digits.append(digit)
            action_array = np.array(digits[::-1], dtype=np.int32)

        elif self.action_type == "box":
            # Clip into the valid range and round (half-to-even) to integers.
            action_array = np.round(np.clip(action, 0, n_actions - 1)).astype(np.int32)
            if action_array.shape != (self.n_components,):
                raise ValueError(
                    f"Box action shape {action_array.shape} != ({self.n_components},)"
                )

        else:
            raise ValueError(f"Unknown action_type: {self.action_type}")

        return action_array

    def _get_observation(self) -> np.ndarray:
        """Return the simulator observation for the configured observability."""
        return self.simulator.get_observation(self.observability)

    def _get_info(self) -> dict[str, Any]:
        """Return a fresh environment info dictionary."""
        return {
            "current_step": self.current_step,
            "max_steps": self.max_steps,
            "n_components": self.n_components,
            "observability": self.observability,
            "action_type": self.action_type,
        }

    def _check_termination(self, sim_info: dict[str, Any]) -> tuple[bool, bool]:
        """Check episode termination conditions.

        Default implementation terminates on budget exhaustion or when more
        than half of the components have failed, and truncates on max steps.
        It is called *before* ``current_step`` is incremented. Can be
        overridden by subclasses.

        Parameters
        ----------
        sim_info : dict
            Simulation step information

        Returns
        -------
        tuple
            (terminated, truncated)
        """
        terminated = False
        truncated = False

        # Budget exhaustion. Only checked when the simulator reports the
        # remaining budget: defaulting a missing key to 0 would terminate
        # every episode after its first step.
        budget_remaining = sim_info.get("budget_remaining")
        if budget_remaining is not None and budget_remaining <= 0:
            terminated = True

        # Too many failed components (> 50%)
        failure_threshold = self.n_components * 0.5
        if sim_info.get("failures", 0) > failure_threshold:
            terminated = True

        # +1 because the step counter has not been incremented yet.
        if self.current_step + 1 >= self.max_steps:
            truncated = True

        return terminated, truncated

    def render(self) -> np.ndarray | str | None:
        """Render the environment state.

        Returns
        -------
        np.ndarray or str or None
            A ``(64, 64, 3)`` uint8 array when ``render_mode == 'rgb_array'``;
            ``None`` otherwise (``'human'`` mode prints to the terminal).
        """
        if self.render_mode == "rgb_array":
            return self._render_rgb_array()
        if self.render_mode != "human":
            return None

        # getattr guard: fall back to plain text when the simulator exposes
        # no (or a falsy) ``console`` attribute.
        if self.rich_display and getattr(self.simulator, "console", None):
            last_info = self.render_history[-1]["info"] if self.render_history else {}
            self.simulator.display_status(last_info)
            return None

        # Simple text output
        print(f"Step {self.current_step}:")
        print(f" States: {self.simulator.states}")
        budget_model = self.simulator.budget
        budget_available = (
            budget_model.available()
            if hasattr(budget_model, "available")
            else budget_model._available_internal()
        )
        print(f" Budget: {budget_available:.0f}")
        # State 0 is counted as failed (same convention as the RGB renderer).
        print(f" Failures: {np.sum(self.simulator.states == 0)}")
        return None

    def _render_rgb_array(self) -> np.ndarray:
        """Create a 64x64 RGB image with one colored cell per component."""
        height, width = 64, 64

        # Square grid layout. max(1, ...) avoids a division by zero when
        # n_components == 0 and zero-sized cells when the grid exceeds 64x64
        # (cells falling outside the image are then simply clipped away).
        grid_size = max(1, int(np.ceil(np.sqrt(self.n_components))))
        component_size = max(1, min(height // grid_size, width // grid_size))

        rgb_array = np.zeros((height, width, 3), dtype=np.uint8)

        for i, state in enumerate(self.simulator.states):
            row, col = divmod(i, grid_size)
            y_start = row * component_size
            y_end = min((row + 1) * component_size, height)
            x_start = col * component_size
            x_end = min((col + 1) * component_size, width)

            # Red = failed (0) or critical (< 3), orange = poor (< 5),
            # yellow = fair (< 7), green = good.
            if state < 3:
                color = [255, 0, 0]
            elif state < 5:
                color = [255, 165, 0]
            elif state < 7:
                color = [255, 255, 0]
            else:
                color = [0, 255, 0]

            rgb_array[y_start:y_end, x_start:x_end] = color

        return rgb_array

    def close(self):
        """Clean up environment resources (closes the simulator if it can)."""
        if hasattr(self.simulator, "close"):
            self.simulator.close()
        self.render_history = []
def make_env_from_config(env_class, config_path: str, **kwargs) -> BaseInfraEnv:
    """Create an environment from a YAML configuration file.

    Constructor arguments for ``env_class`` are read from the top-level
    ``environment`` mapping of the file. Entries in ``kwargs`` take
    precedence over values from the file. A missing or empty
    ``environment`` section, or an empty file, is treated as "no settings".

    Parameters
    ----------
    env_class : class
        Environment class that inherits from BaseInfraEnv
    config_path : str
        Path to YAML configuration file (read as UTF-8)
    **kwargs
        Additional keyword arguments to override config

    Returns
    -------
    BaseInfraEnv
        Configured environment instance

    Raises
    ------
    ValueError
        If the file's top level, or its ``environment`` entry, is not a mapping.
    """
    import yaml

    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Top level of {config_path!r} must be a mapping, got {type(config).__name__}"
        )

    # ``environment:`` with no value also parses as None.
    env_config = config.get("environment")
    if env_config is None:
        env_config = {}
    if not isinstance(env_config, dict):
        raise ValueError(
            f"'environment' in {config_path!r} must be a mapping, "
            f"got {type(env_config).__name__}"
        )

    # Build a new dict so the parsed config is left untouched; kwargs override.
    return env_class(**{**env_config, **kwargs})