Source code for infralib.models.hierarchy

"""Hierarchy system with unified interface."""

from dataclasses import dataclass, field
from typing import Any

import numpy as np

from .base import BaseModel, ModelContext


[docs] @dataclass class HierarchyLevel: """Definition of a hierarchy level.""" name: str parent_level: str | None = None properties: dict[str, Any] = field(default_factory=dict) aggregation_rules: dict[str, str] = field(default_factory=dict)
class HierarchyModel(BaseModel):
    """Base class for hierarchy models with unified interface.

    Subclasses describe the hierarchy by overriding the ``get_*`` methods
    (which raise ``NotImplementedError`` here); this class combines that
    structure with a state array to produce per-group aggregate metrics.
    """

    def compute(self, context: ModelContext) -> dict[str, Any]:
        """Compute hierarchy-based metrics.

        Args:
            context: Contains states and component information

        Returns:
            Dict of hierarchy metrics and aggregations, keyed by
            ``"<level name>_metrics"``.
        """
        # ``validate_context`` is not defined in this class; presumably it is
        # inherited from BaseModel -- confirm there.
        self.validate_context(context)
        return self._compute_hierarchy_metrics(context)

    def _compute_hierarchy_metrics(self, context: ModelContext) -> dict[str, Any]:
        """Internal computation of hierarchy metrics.

        Returns an empty dict when ``context.states`` is ``None``. The level
        named ``"component"`` is skipped, and levels that yield no group
        metrics are left out of the result.
        """
        metrics: dict[str, Any] = {}
        if context.states is None:
            return metrics

        for level in self.get_hierarchy_levels():
            if level.name == "component":
                continue
            level_metrics = self._compute_level_metrics(level, context.states)
            if level_metrics:
                metrics[f"{level.name}_metrics"] = level_metrics
        return metrics

    def _compute_level_metrics(
        self, level: HierarchyLevel, states: np.ndarray
    ) -> dict[str, Any]:
        """Compute metrics for a specific hierarchy level.

        Args:
            level: The level whose groups are aggregated.
            states: State array indexed by component id along its first axis.

        Returns:
            Mapping of group id to the aggregate dict produced by
            ``_aggregate_states``. Groups with no usable component ids are
            omitted.
        """
        n_states = len(states)
        metrics: dict[str, Any] = {}
        for group in self.get_all_groups(level.name):
            components = self.get_group_components(group, level.name)
            component_states = [
                states[cid]
                for cid in components
                # Accept NumPy integer ids as well as Python ints, and reject
                # negative ids: NumPy would otherwise index them from the end
                # of the array and silently read another component's state.
                if isinstance(cid, (int, np.integer)) and 0 <= cid < n_states
            ]
            if component_states:
                metrics[group] = self._aggregate_states(component_states, level)
        return metrics

    def _aggregate_states(
        self, states: list, level: HierarchyLevel
    ) -> dict[str, float]:
        """Aggregate component states based on level rules.

        Args:
            states: Non-empty list of component states.
            level: Level whose ``aggregation_rules`` add extra fields.

        Returns:
            Dict that always contains ``"mean"``, ``"min"``, ``"max"`` and
            ``"failures"`` (the count of entries equal to 0), plus one entry
            per aggregation rule. A rule whose field name matches one of the
            base keys overwrites it; rules other than ``min``/``max``/
            ``mean``/``sum`` are ignored.
        """
        state_array = np.array(states)
        aggregation = {
            "mean": float(np.mean(state_array)),
            "min": float(np.min(state_array)),
            "max": float(np.max(state_array)),
            "failures": int(np.sum(state_array == 0)),
        }

        reducers = {"min": np.min, "max": np.max, "mean": np.mean, "sum": np.sum}
        # ``field_name`` rather than ``field`` so the loop variable does not
        # shadow ``dataclasses.field`` imported at module level.
        for field_name, rule in level.aggregation_rules.items():
            reducer = reducers.get(rule)
            if reducer is not None:
                aggregation[field_name] = float(reducer(state_array))
        return aggregation

    def get_hierarchy_levels(self) -> list[HierarchyLevel]:
        """Return ordered hierarchy levels from bottom to top."""
        raise NotImplementedError

    def get_component_group(self, component_id: int, level: str) -> str | None:
        """Get the group a component belongs to at a level."""
        raise NotImplementedError

    def get_group_components(self, group_id: str, level: str) -> list[int]:
        """Get all components in a group."""
        raise NotImplementedError

    def get_all_groups(self, level: str) -> list[str]:
        """Get all group IDs at a hierarchy level."""
        raise NotImplementedError

    def get_group_property(self, group_id: str, level: str, property: str) -> Any:
        """Get a property of a hierarchy group."""
        raise NotImplementedError

    def reset(self, context: ModelContext | None = None):
        """Reset hierarchy model (no-op in this base class)."""
        pass

    @classmethod
    def get_parameter_spec(cls) -> dict[str, tuple[type, tuple[float, float], str]]:
        """Hierarchy models typically don't have numeric parameters."""
        return {}

    def _setup(self):
        """Setup hierarchy structure (no-op here; subclasses may override)."""
        pass
class GeneralHierarchy(HierarchyModel):
    """General-purpose hierarchy for any domain.

    Membership is tracked in two indexes that are kept in sync:
    ``assignments`` maps ``component_id -> {level: group}`` and ``groups``
    maps ``level -> group -> {"components": set, "properties": dict}``.
    """

    def __init__(self, level_definitions: list[HierarchyLevel] | None = None):
        """Create hierarchy with user-defined levels.

        Args:
            level_definitions: Levels to use. ``None`` or an empty list
                selects the default two-level hierarchy.
        """
        # Attributes are set before calling the base initializer; presumably
        # BaseModel.__init__ invokes ``_setup()``, which reads
        # ``level_definitions`` -- confirm in BaseModel.
        self.level_definitions = level_definitions or self._default_levels()
        self.assignments: dict[int, dict[str, str]] = {}
        self.groups: dict[str, dict[str, dict[str, Any]]] = {}
        super().__init__()

    def _default_levels(self) -> list[HierarchyLevel]:
        """Default two-level hierarchy."""
        return [
            HierarchyLevel("component"),
            HierarchyLevel(
                "group", "component", aggregation_rules={"condition": "min"}
            ),
        ]

    def _setup(self):
        """Validate and setup hierarchy."""
        self._validate_hierarchy()

    def _validate_hierarchy(self):
        """Ensure hierarchy is well-formed.

        Raises:
            ValueError: If a level names a parent that is not itself a
                defined level.
        """
        level_names = {level.name for level in self.level_definitions}
        for level in self.level_definitions:
            if level.parent_level and level.parent_level not in level_names:
                raise ValueError(f"Parent level {level.parent_level} not found")

    def get_hierarchy_levels(self) -> list[HierarchyLevel]:
        """Return the level definitions this hierarchy was created with."""
        return self.level_definitions

    def assign_component(self, component_id: int, assignments: dict[str, str]):
        """Assign component to hierarchy groups.

        The new mapping replaces any earlier assignment of the component;
        the component is removed from the groups it previously belonged to
        so that group membership and ``get_component_group`` stay consistent.

        Args:
            component_id: Component to assign.
            assignments: Mapping of level name to group id.
        """
        # Drop stale membership left by a previous assignment. The group
        # entries themselves are kept because they may carry properties.
        for level, group in self.assignments.get(component_id, {}).items():
            entry = self.groups.get(level, {}).get(group)
            if entry is not None:
                entry["components"].discard(component_id)

        # Store a copy so later mutation of the caller's dict cannot desync
        # ``assignments`` from ``groups``.
        self.assignments[component_id] = dict(assignments)

        for level, group in assignments.items():
            level_groups = self.groups.setdefault(level, {})
            entry = level_groups.setdefault(
                group, {"components": set(), "properties": {}}
            )
            entry["components"].add(component_id)

    def get_component_group(self, component_id: int, level: str) -> str | None:
        """Get the group a component belongs to at a level.

        Returns ``None`` when the component or the level is unknown.
        """
        return self.assignments.get(component_id, {}).get(level)

    def get_group_components(self, group_id: str, level: str) -> list[int]:
        """Get all components in a group.

        Returns an empty list for an unknown level or group. The list is
        built from a set, so its order is not specified.
        """
        entry = self.groups.get(level, {}).get(group_id)
        if entry is None:
            return []
        return list(entry["components"])

    def get_all_groups(self, level: str) -> list[str]:
        """Get all group IDs at a hierarchy level."""
        return list(self.groups.get(level, {}))

    def set_group_property(self, group_id: str, level: str, property: str, value: Any):
        """Set a property for a hierarchy group.

        The level and group entries are created if they do not exist yet.
        """
        level_groups = self.groups.setdefault(level, {})
        entry = level_groups.setdefault(
            group_id, {"components": set(), "properties": {}}
        )
        entry["properties"][property] = value

    def get_group_property(self, group_id: str, level: str, property: str) -> Any:
        """Get a property of a hierarchy group.

        Returns ``None`` when the level, group or property is unknown.
        """
        entry = self.groups.get(level, {}).get(group_id)
        if entry is None:
            return None
        return entry["properties"].get(property)

    def compute_group_metric(
        self, group_id: str, level: str, values: np.ndarray, aggregation: str = "mean"
    ) -> float:
        """Compute aggregated metric for a group.

        Args:
            group_id: Group whose components are aggregated.
            level: Level the group belongs to.
            values: Array indexed by component id.
            aggregation: One of ``"mean"``, ``"min"``, ``"max"``, ``"sum"``
                or ``"weighted_mean"``; any other value falls back to the
                mean.

        Returns:
            The aggregated value, or ``0.0`` when the group has no
            components.
        """
        components = self.get_group_components(group_id, level)
        if not components:
            return 0.0

        component_values = values[components]

        if aggregation == "weighted_mean":
            # The weight is a group-level property, so the lookup is done
            # once and the same weight is applied to every component.
            # NOTE(review): with identical weights this equals the plain
            # mean; per-component weights are not stored anywhere visible
            # in this class -- confirm the intended semantics.
            weight = self.get_group_property(group_id, level, "weight") or 1.0
            weights = [weight] * len(components)
            return float(np.average(component_values, weights=weights))

        reducers = {"mean": np.mean, "min": np.min, "max": np.max, "sum": np.sum}
        # Unknown aggregation names deliberately fall back to the mean.
        reducer = reducers.get(aggregation, np.mean)
        return float(reducer(component_values))
class SimpleHierarchy(GeneralHierarchy):
    """Two-tier hierarchy: a ``component`` level and a ``system`` level."""

    def __init__(self):
        """Build the fixed component/system level pair.

        The ``system`` level aggregates ``condition`` with ``min`` and
        ``cost`` with ``sum``.
        """
        system_level = HierarchyLevel(
            "system",
            "component",
            aggregation_rules={"condition": "min", "cost": "sum"},
        )
        super().__init__([HierarchyLevel("component"), system_level])
class MultiLevelHierarchy(GeneralHierarchy):
    """Hierarchy with a configurable number of stacked levels."""

    def __init__(self, n_levels: int = 3):
        """Build a chain of levels on top of the component level.

        Levels above ``component`` are named ``level_1``, ``level_2``, ...
        Each one declares the level before it as its parent and aggregates
        ``condition`` with ``min``.

        Args:
            n_levels: Number of hierarchy levels (including component level)
        """
        # Lay out the names first; each consecutive pair is (parent, child).
        names = ["component"] + [f"level_{idx}" for idx in range(1, n_levels)]

        tiers = [HierarchyLevel(names[0])]
        for parent_name, own_name in zip(names, names[1:]):
            tiers.append(
                HierarchyLevel(
                    own_name, parent_name, aggregation_rules={"condition": "min"}
                )
            )
        super().__init__(tiers)