Use Case Implementation ======================= This page explains how **EnerGNN** can be leveraged to address your own custom use cases. But first, make sure to check the :doc:`basics` and :doc:`tutorial_notebook` pages. Each use case implementation shall encompass their underlying business logic: - How input data (i.e., **contexts**, see :term:`Context`) are defined and loaded in memory, - How output data (i.e., **decisions**, see :term:`Decision`) should look like, - How the gradient shall be estimated (closed-form for supervised learning vs. Monte Carlo estimation for more complex cases, minimizing the **objective function**, see :term:`Objective Function`). Depending on the use case, very different implementation choices can be made, but all should respect the interface defined in :mod:`energnn.problem`. The following guide walks you through the major steps in implementing your own **EnerGNN** use case. Overview -------- Implementing your custom use case requires the following class implementations. 1. :term:`Problem` (:class:`~energnn.problem.Problem`) -- Implements the logic for a single **problem** instance (context, gradient, score). 2. :term:`Problem Batch` (:class:`~energnn.problem.ProblemBatch`) -- Handles **batching** of multiple problem instances together. The implementation can be optimized for parallel computation, and even leverage GPU parallelization. 3. :class:`~energnn.problem.ProblemLoader` -- Iterates over a whole dataset of problems, by returning a different **problem batch** at every iteration. All three classes should however share two common properties, :attr:`~energnn.problem.Problem.context_structure` and :attr:`~energnn.problem.Problem.decision_structure`, which define the name of the object classes and of their respective ports and features appearing in **contexts** and **decisions**. ---------- Step 0 — Define Graph Structures -------------------------------- **EnerGNN** uses the class :class:`~energnn.graph.GraphStructure` to understand the format of your data. You must define a structure for your **contexts** (see :term:`Context`) and your **decisions** (see :term:`Decision`). They are mandatory properties for your :class:`~energnn.problem.Problem`, :class:`~energnn.problem.ProblemBatch` and :class:`~energnn.problem.ProblemLoader` implementations. .. code-block:: python from energnn.graph import HyperEdgeSetStructure, GraphStructure # Example: a context graph with lines, switches, generators and loads CONTEXT_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={ "lines": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=["r", "x"]), "switches": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=None), "generators": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p0", "q0"]), "loads": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p", "q"]), }) # Let us say that we wish to predict a log-probability for each switch to be open, # along with a generation variation. DECISION_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={ "switches": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["log_prob"]), "generators": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["delta_p"]), }) **Important constraints:** 1. All classes in the **context** shall be at least of order 1 (i.e., have 1 or more ports); 2. All classes appearing in the **decision** shall also be appearing in the **context**; 3. No port can be predicted by the GNN, so all attributes :code:`port_list` in the decision structure shall be None; For now, there is no support for global features (i.e., that would not be borne by a specific object), but feel free to reach out if that's something you would like to see included. ------------------ Step 1 — Implement the Problem Interface ---------------------------------------- Your implementation of the class :class:`~energnn.problem.Problem` should represent a single problem instance. You must implement the following properties: - :attr:`~energnn.problem.Problem.context_structure`, - :attr:`~energnn.problem.Problem.decision_structure`, And the following methods: - :meth:`~energnn.problem.Problem.get_context`: Returns the **context** graph :math:`x`, instantiated as a :class:`~energnn.graph.JaxGraph`, - :meth:`~energnn.problem.Problem.get_gradient`: Computes :math:`\nabla_y f(y;x)` for a given **decision** :math:`y`, instantiated as a :class:`~energnn.graph.JaxGraph`, - :meth:`~energnn.problem.Problem.get_score`: Computes :math:`f(y;x)` for a given **decision** :math:`y` as a :code:`float`. **Tracking relevant quantities**: All three methods have a key word argument :attr:`get_info` to trigger an optional behavior. If :code:`True`, these methods return optional dictionaries that are passed to your experiment tracker. It's useful for debugging and tracking, but not necessary in your first implementation. **Data representation**: Contexts, decisions and gradients are all instantiated as :class:`~energnn.graph.JaxGraph`, which is a version of :class:`~energnn.graph.Graph` designed to work seamlessly with JAX. **Decoupling gradients and score**: You can use different objective functions in :meth:`~energnn.problem.Problem.get_score` and in :meth:`~energnn.problem.Problem.get_gradient`. For instance, you can use a non-differentiable function :math:`f` as a score, and a differentiable function :math:`f'` in :meth:`~energnn.problem.Problem.get_gradient`. Loosely speaking, :meth:`~energnn.problem.Problem.get_score` just has to return a scalar value that quantifies how good a decision is, and :meth:`~energnn.problem.Problem.get_gradient` just has to return the opposite of a direction of improvement for a decision. .. code-block:: python from typing import Any import jax.numpy as jnp from energnn.graph import Graph, JaxGraph, GraphStructure from energnn.problem import Problem class MyProblem(Problem): def __init__(self, path: Any): # Implement your own data import, and store relevant state data self.context, self.state: tuple[JaxGraph, Any] = self._import_from_pypowsybl(path) @property def context_structure(self) -> GraphStructure: return CONTEXT_STRUCTURE @property def decision_structure(self) -> GraphStructure: return DECISION_STRUCTURE def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]: return self.context, {} def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]: # Implement your own gradient estimation method grad: JaxGraph = self._estimate_gradient(decision, self.state) return grad, {} def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[float, dict[str, Any]]: # Implement your own score estimation method grad: float = self._estimate_score(decision, self.state) return grad, {} def save(self, path): # Implement your own save method pass @classmethod def load(cls, path): # Implement your own load method pass Step 2 — Handle Batching (Problem Batch) ---------------------------------------- To train efficiently on GPUs, multiple problems are grouped together into a **problem batch** (:term:`Problem Batch`). The batch interface mirrors the :term:`Problem` interface but operates on concatenated graphs. It is very common that the different problem instances within a batch have a different amount of objects for a given class. For instance, consider a batch with 2 instances, where: - The first **context** has 5 switches, - The second **context** has 7 switches. To collate them together, we have to pad the first **context** with 2 fictitious switches, so that the two **contexts** end up having the same number of switches. The following code snippet shows how to do so. .. code-block:: from energnn.graph import Graph, GraphShape, collate_graphs, max_shape, separate_graphs context_1: Graph = ... context_2: Graph = ... # Step 1 : Get the two graph shapes # The true_shape property computes the number of non fictitious objects of each class shape_1 = context_1.true_shape shape_2 = context_2.true_shape # Step 2 : Compute the largest shape # It computes for each object class the maximum number of objects in the list max_shape = max_shape([shape_1, shape_2]) # Step 3 : Pad all contexts with fictitious objects if required context_1.pad(max_shape) context_2.pad(max_shape) # The true_shape property is not altered by the padding, but the current_shape is. # Step 4 : Collate contexts together context_batch = collate_graphs([context_1, context_2]) # Et voilà, you have a context batch filled with fictitious objects if necessary. # We can pass it to a model to compute a batch of decisions. # The EnerGNN models are implemented to return 0 values on fictitious objects, # And to keep track of which objects actually are fictitious or not. decision_batch = my_model.forward_batch(context_batch) # Wait... What if we need to split this batch of decisions, # and get rid of fictitious objects ? # Step 5 : Split decisions apart decision_1, decision_2 = separate_graphs(decision_batch) # Step 6 : Unpad decisions decision_1.unpad() decision_2.unpad() # And now we have decisions without any fictitious object! The following :class:`~energnn.problem.ProblemBatch` implementation assumes that : - Single :class:`~energnn.problem.Problem` instances have been generated and saved beforehand, - The gradient computation can be performed in batch, - A :code:`max_shape` has been computed over the whole dataset. .. code-block:: python from typing import Any from energnn.problem import ProblemBatch, Problem from energnn.graph import JaxGraph, Graph, GraphStructure, collate_graphs class MyBatch(ProblemBatch): def __init__(self, path_list: list[str], max_shape: GraphShape): self.problems: list[MyProblem] = [MyProblem.load(path) for path in path_list] # Get all contexts, pad them and collate them together. context_list, _ = zip([pb.get_context() for pb in self.problems]) np_context_list = [context.to_numpy_graph() for context in context_list] [np_context.pad(max_shape) for np_context in np_context_list] np_context_batch = collate_graphs(np_context_list) self.context_batch = JaxGraph.from_numpy_graph(np_context_batch) @property def context_structure(self) -> GraphStructure: return CONTEXT_STRUCTURE @property def decision_structure(self) -> GraphStructure: return DECISION_STRUCTURE def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]: return self.context_batch, {} def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]: batch_grad = self._compute_batch_grad(decision, self.problem_list) return batch_grad, {} def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[list[float], dict[str, Any]]: score_list = self._compute_score_list(decision, self.problem_list) return score_list, {} Step 3 — Data Loading (ProblemLoader) ------------------------------------- The :class:`energnn.problem.ProblemLoader` is an iterator that yields batches. .. code-block:: python from typing import Any, Iterator from energnn.problem import ProblemLoader, ProblemBatch class MyLoader(ProblemLoader): def __init__(self, dataset: list[str], batch_size: int, shuffle: bool = False): self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle self._current_idx = 0 def __iter__(self) -> Iterator[ProblemBatch]: # Handle shuffling if needed self._current_idx = 0 return self def __next__(self) -> ProblemBatch: if self._current_idx >= len(self.dataset): raise StopIteration # Slice dataset and return a MyBatch instance path_list = self.dataset[self._current_idx:self._current_idx+self.batch_size] self._current_idx += self.batch_size return MyBatch(path_list) def __len__(self) -> int: return len(self.dataset) // self.batch_size Interface Checklist ------------------- When implementing your custom use case, ensure these requirements are met: - ``context_structure`` and ``decision_structure`` properties are defined. - ``get_context()`` returns a :class:`energnn.graph.JaxGraph`. - ``get_gradient()`` returns a :class:`energnn.graph.JaxGraph` with the same topology as the decision. - ``get_score()`` returns a scalar (for :class:`~energnn.problem.Problem`) or a list of scalars (for :class:`~energnn.problem.ProblemBatch`). - Graphs are correctly converted between :class:`~energnn.graph.Graph` (NumPy-based, useful for building/collating) and :class:`~energnn.graph.JaxGraph` (JAX-based, used by the models). Summary ------- By implementing these interfaces, your problem becomes fully compatible with EnerGNN's models and trainers. You can find more practical examples in the :doc:`tutorial_notebook` or by looking at the ``tests/utils.py`` file in the repository. Next steps ---------- - See :doc:`basics` for more details on H2MG graphs. - Visit the :doc:`reference/index` for the full API specification of the problem module.