"""Unified base model architecture for InfraLib."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import numpy as np
[docs]
@dataclass
class ModelContext:
"""Context object passed to all models containing state and dependencies."""
states: np.ndarray
actions: np.ndarray | None = None
time_step: int = 0
dynamics: Any = None
cost: Any = None
budget: Any = None
hierarchy: Any = None
metadata: Any = None
component_ids: np.ndarray | None = None
history: dict[str, list] | None = None
custom_data: dict[str, Any] | None = None
[docs]
def get_model(self, model_type: str) -> Any:
"""Get a model by type name."""
return getattr(self, model_type, None)
class BaseModel(ABC):
    """Unified base class for all infrastructure models.

    Subclasses implement ``_setup``, ``compute``, ``reset`` and
    ``get_parameter_spec``. Construction does three things in order:
    it stores the keyword parameters, bounds-checks them against the
    spec, and then calls ``_setup``.
    """

    def __init__(self, **params):
        """Initialize with parameters and validate them.

        Args:
            **params: Model parameters, stored as-is on ``self.params``.

        Raises:
            ValueError: If a parameter named in ``get_parameter_spec`` lies
                outside its inclusive ``(min, max)`` range.
        """
        self.params = params
        self._validate_init_params()
        # _setup runs only after validation, so it can rely on in-range params.
        self._setup()

    def _validate_init_params(self):
        """Check each spec'd parameter against its inclusive ``(min, max)`` bounds.

        Parameters without a spec entry are accepted unchecked. Scalars
        are compared directly. Lists, tuples and NumPy arrays of any rank
        (including 0-d) are checked element-wise. An empty sequence passes.
        The declared type in the spec is not enforced here.

        Raises:
            ValueError: If any value (or sequence element) is out of range.
                Non-numeric values may instead raise ``TypeError`` from the
                comparison.
        """
        spec = self.get_parameter_spec()
        for name, value in self.params.items():
            if name not in spec:
                continue
            _param_type, (min_val, max_val), _description = spec[name]
            if isinstance(value, (list, tuple, np.ndarray)):
                # Element-wise comparison works for arrays of any rank. The
                # old per-item chained comparison broke on rows of a 2-D array.
                values = np.asarray(value)
                if not np.all((values >= min_val) & (values <= max_val)):
                    raise ValueError(
                        f"{name} values must be between {min_val} and {max_val}"
                    )
            elif not min_val <= value <= max_val:
                raise ValueError(f"{name} must be between {min_val} and {max_val}")

    @abstractmethod
    def _setup(self):
        """Setup model after parameter validation.

        Called automatically from ``__init__`` once parameter validation
        has passed. Use this to initialize internal state, build matrices,
        etc.
        """
        pass

    @abstractmethod
    def compute(self, context: "ModelContext") -> Any:
        """Main computation method, shared by all models.

        Args:
            context: ModelContext with current state and dependencies.

        Returns:
            Model-specific output.
        """
        pass

    @abstractmethod
    def reset(self, context: "ModelContext | None" = None):
        """Reset model to initial state.

        Args:
            context: Optional context for state-dependent resets.
        """
        pass

    @classmethod
    @abstractmethod
    def get_parameter_spec(cls) -> dict[str, tuple[type, tuple[float, float], str]]:
        """Get parameter specifications.

        Returns:
            Dict of ``param_name -> (type, (min, max), description)``.
            ``_validate_init_params`` checks ``(min, max)`` as inclusive bounds.
        """
        pass

    @classmethod
    def get_required_models(cls) -> list[str]:
        """List the model dependencies this model needs.

        Override to name the context attributes that must be set, e.g.
        ``['dynamics', 'cost', 'hierarchy']``. The default is no dependencies.
        """
        return []

    def validate_context(self, context: "ModelContext"):
        """Ensure every model named by ``get_required_models`` is present.

        Raises:
            ValueError: For the first required model whose
                ``context.get_model(name)`` is ``None``.
        """
        for model_name in self.get_required_models():
            if context.get_model(model_name) is None:
                raise ValueError(f"Required model '{model_name}' not found in context")