Use Case Implementation

This page explains how EnerGNN can be leveraged to address your own custom use cases. Before reading it, make sure you have gone through the Basics and Tutorial pages.

Each use case implementation encapsulates its own business logic:

  • How input data (i.e., contexts, see Context) are defined and loaded in memory,

  • What output data (i.e., decisions, see Decision) should look like,

  • How the gradient of the objective function (see Objective Function) shall be estimated: in closed form for supervised learning, or through Monte Carlo estimation for more complex cases.
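The two flavours of gradient estimation can be contrasted on toy problems. The sketch below uses plain Python lists rather than JaxGraph objects; the squared-error objective, the interpretation of logits as switch-opening probabilities and the black-box cost are illustrative assumptions, not part of the EnerGNN API. The Monte Carlo part uses the score-function (REINFORCE) estimator, which is one common choice among several.

```python
import math
import random
from typing import Callable


def closed_form_gradient(decision: list[float], target: list[float]) -> list[float]:
    """Supervised case: gradient of 0.5 * sum((d - t)**2) w.r.t. the decision, i.e. d - t."""
    return [d - t for d, t in zip(decision, target)]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def monte_carlo_gradient(
    logits: list[float],
    cost: Callable[[list[int]], float],
    n_samples: int = 20_000,
    seed: int = 0,
) -> list[float]:
    """Score-function (REINFORCE) estimate of d E[cost(x)] / d logits.

    Each (hypothetical) switch i is open (x_i = 1) with probability
    sigmoid(logits[i]), independently of the others; cost is a black box.
    """
    rng = random.Random(seed)
    probs = [sigmoid(z) for z in logits]
    grad = [0.0] * len(logits)
    for _ in range(n_samples):
        x = [1 if rng.random() < p else 0 for p in probs]
        c = cost(x)
        # For a Bernoulli variable with logit z: d log P(x) / dz = x - sigmoid(z).
        for i, (xi, p) in enumerate(zip(x, probs)):
            grad[i] += c * (xi - p)
    return [g / n_samples for g in grad]
```

For instance, with cost=sum (the number of open switches) and zero logits, the exact gradient is sigmoid(0) * (1 - sigmoid(0)) = 0.25 per switch, and the estimate lands close to it. Unlike the closed-form case, the estimate is noisy, and its variance shrinks as n_samples grows.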

Depending on the use case, very different implementation choices can be made, but all should respect the interface defined in energnn.problem.

The following guide walks you through the major steps in implementing your own EnerGNN use case.

Overview

Your custom use case requires implementations of the following three classes:

  1. Problem (Problem) – Implements the logic for a single problem instance (context, gradient, score).

  2. Problem Batch (ProblemBatch) – Handles batching of multiple problem instances together. The implementation can be optimized for parallel computation, and even leverage GPU parallelization.

  3. ProblemLoader – Iterates over a whole dataset of problems, returning a new problem batch at each iteration.

All three classes should share two common properties, context_structure and decision_structure, which define the names of the object classes appearing in contexts and decisions, along with their respective ports and features.


Step 0 — Define Graph Structures

EnerGNN uses the class GraphStructure to understand the format of your data. You must define a structure for your contexts (see Context) and your decisions (see Decision). They are mandatory properties for your Problem, ProblemBatch and ProblemLoader implementations.

from energnn.graph import HyperEdgeSetStructure, GraphStructure

# Example: a context graph with lines, switches, generators and loads
CONTEXT_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={
    "lines": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=["r", "x"]),
    "switches": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=None),
    "generators": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p0", "q0"]),
    "loads": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p", "q"]),
})

# Let us say that we wish to predict a log-probability for each switch to be open,
# along with a generation variation.
DECISION_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={
    "switches": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["log_prob"]),
    "generators": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["delta_p"]),
})

Important constraints:

  1. All classes in the context shall be at least of order 1 (i.e., have 1 or more ports);

  2. All classes appearing in the decision shall also appear in the context;

  3. No port can be predicted by the GNN, so the port_list of every class in the decision structure shall be None.
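Before wiring structures into EnerGNN, the three constraints can be checked with a few lines of code. The sketch below is a hypothetical helper, not part of EnerGNN: it works on plain dictionaries that mirror the port_list and feature_list arguments of HyperEdgeSetStructure.from_list, not on GraphStructure objects.

```python
from typing import Optional

# Hypothetical plain-dict mirror of a structure (NOT the EnerGNN GraphStructure API):
# class name -> {"ports": list of port names or None, "features": list or None}
StructureDict = dict[str, dict[str, Optional[list[str]]]]


def check_structures(context: StructureDict, decision: StructureDict) -> list[str]:
    """Return a description of every violated constraint (empty list if none)."""
    errors = []
    # Constraint 1: every context class has at least one port.
    for name, spec in context.items():
        if not spec.get("ports"):
            errors.append(f"context class '{name}' has no port")
    for name, spec in decision.items():
        # Constraint 2: every decision class also appears in the context.
        if name not in context:
            errors.append(f"decision class '{name}' is absent from the context")
        # Constraint 3: ports cannot be predicted, so port lists must be None.
        if spec.get("ports") is not None:
            errors.append(f"decision class '{name}' declares ports")
    return errors


# The example structures from Step 0, translated into this format.
context = {
    "lines": {"ports": ["bus1", "bus2"], "features": ["r", "x"]},
    "switches": {"ports": ["bus1", "bus2"], "features": None},
    "generators": {"ports": ["bus"], "features": ["p0", "q0"]},
    "loads": {"ports": ["bus"], "features": ["p", "q"]},
}
decision = {
    "switches": {"ports": None, "features": ["log_prob"]},
    "generators": {"ports": None, "features": ["delta_p"]},
}
print(check_structures(context, decision))  # → []
```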

For now, global features (i.e., features that are not borne by a specific object) are not supported, but feel free to reach out if that’s something you would like to see included.


Step 1 — Implement the Problem Interface

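Your implementation of the class Problem should represent a single problem instance. You must implement the following properties:

  • context_structure,

  • decision_structure.

And the following methods:

  • get_context(),

  • get_gradient(),

  • get_score().

The example below also implements save() and load(), which the batch implementation of Step 2 relies on to reload problem instances from disk.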

Tracking relevant quantities: All three methods accept a get_info keyword argument that triggers an optional behavior. When it is True, the dictionary returned alongside the main output can be filled with additional quantities, which are passed to your experiment tracker. This is useful for debugging and tracking, but not necessary in your first implementation.
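As an illustration of this convention, the toy function below always returns a (value, info) pair, but only fills the dictionary when get_info is True, so that the extra work is skipped during normal training. It is a sketch under simplifying assumptions: plain Python lists stand in for JaxGraph objects, and the mean-absolute-error score and the info keys are hypothetical.

```python
from typing import Any


def toy_score(decision: list[float], target: list[float], get_info: bool = False) -> tuple[float, dict[str, Any]]:
    """Hypothetical score (mean absolute error) following the (value, info) convention."""
    errors = [abs(d - t) for d, t in zip(decision, target)]
    score = sum(errors) / len(errors)
    info: dict[str, Any] = {}
    if get_info:
        # Extra diagnostics, only computed when the caller asks for them.
        info["max_abs_error"] = max(errors)
        info["n_objects"] = len(errors)
    return score, info
```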

Data representation: Contexts, decisions and gradients are all instantiated as JaxGraph, which is a version of Graph designed to work seamlessly with JAX.

Decoupling gradients and score: You can use different objective functions in get_score() and in get_gradient(). For instance, you can use a non-differentiable function \(f\) as a score, and return in get_gradient() the gradient of a differentiable surrogate \(\tilde{f}\). Loosely speaking, get_score() just has to return a scalar value that quantifies how good a decision is, and get_gradient() just has to return a vector whose opposite is a direction of improvement for the decision.
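A toy example of such a decoupling, in plain Python and with hypothetical functions: the score counts how many objects exceed a limit (piecewise constant, so lower is better but its gradient is zero almost everywhere), while the gradient comes from a smooth quadratic penalty on the excess. Following the opposite of that gradient reduces the violations, which is all the score asks for.

```python
def violation_score(decision: list[float], limit: float) -> float:
    """Score f (lower is better): number of objects whose value exceeds `limit`.

    Piecewise constant, hence useless as a source of gradients.
    """
    return float(sum(1 for d in decision if d > limit))


def surrogate_gradient(decision: list[float], limit: float) -> list[float]:
    """Gradient of the differentiable surrogate 0.5 * sum(max(0, d - limit)**2)."""
    return [max(0.0, d - limit) for d in decision]


decision = [0.5, 1.5, 2.0]
grad = surrogate_gradient(decision, limit=1.0)            # [0.0, 0.5, 1.0]
improved = [d - 1.0 * g for d, g in zip(decision, grad)]  # one step of size 1.0
print(violation_score(decision, 1.0), violation_score(improved, 1.0))  # 2.0 0.0
```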

from typing import Any
from energnn.graph import JaxGraph, GraphStructure
from energnn.problem import Problem

class MyProblem(Problem):
    def __init__(self, path: Any):
        # Implement your own data import, and store relevant state data.
        # _import_from_pypowsybl is a placeholder for your own loading helper.
        context, state = self._import_from_pypowsybl(path)
        self.context: JaxGraph = context
        self.state: Any = state

    @property
    def context_structure(self) -> GraphStructure:
        return CONTEXT_STRUCTURE

    @property
    def decision_structure(self) -> GraphStructure:
        return DECISION_STRUCTURE

    def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        return self.context, {}

    def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        # Implement your own gradient estimation method (placeholder helper).
        grad: JaxGraph = self._estimate_gradient(decision, self.state)
        return grad, {}

    def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[float, dict[str, Any]]:
        # Implement your own score estimation method (placeholder helper).
        score: float = self._estimate_score(decision, self.state)
        return score, {}

    def save(self, path: str) -> None:
        # Implement your own save method
        raise NotImplementedError

    @classmethod
    def load(cls, path: str) -> "MyProblem":
        # Implement your own load method
        raise NotImplementedError

Step 2 — Handle Batching (Problem Batch)

To train efficiently on GPUs, multiple problem instances are grouped together into a problem batch (ProblemBatch). The batch interface mirrors the Problem interface, but operates on batched graphs obtained by padding and collating the graphs of the individual instances.

Problem instances within a batch commonly have different numbers of objects of a given class. For instance, consider a batch with 2 instances, where:

  • The first context has 5 switches,

  • The second context has 7 switches.

To collate them together, we have to pad the first context with 2 fictitious switches, so that the two contexts end up having the same number of switches. The following code snippet shows how to do so.

from energnn.graph import Graph, collate_graphs, max_shape, separate_graphs

context_1: Graph = ...
context_2: Graph = ...

# Step 1: Get the two graph shapes.
# The true_shape property gives the number of non-fictitious objects of each class.
shape_1 = context_1.true_shape
shape_2 = context_2.true_shape

# Step 2: Compute the largest shape.
# For each object class, it takes the maximum number of objects over the list.
# (Named target_shape to avoid shadowing the max_shape function.)
target_shape = max_shape([shape_1, shape_2])

# Step 3: Pad all contexts with fictitious objects if required.
context_1.pad(target_shape)
context_2.pad(target_shape)

# The true_shape property is not altered by the padding, but the current_shape is.

# Step 4: Collate contexts together.
context_batch = collate_graphs([context_1, context_2])

# Et voilà, you have a context batch filled with fictitious objects where necessary.
# We can pass it to a model to compute a batch of decisions.
# EnerGNN models are implemented to return 0 values on fictitious objects,
# and to keep track of which objects are actually fictitious.
# (my_model stands for a model instance you created beforehand.)
decision_batch = my_model.forward_batch(context_batch)

# Wait... What if we need to split this batch of decisions
# and get rid of fictitious objects?

# Step 5: Split decisions apart.
decision_1, decision_2 = separate_graphs(decision_batch)

# Step 6: Unpad decisions.
decision_1.unpad()
decision_2.unpad()

# And now we have decisions without any fictitious object!
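The steps above can be mimicked in plain Python to see what padding actually does. In this sketch, which is not EnerGNN code, a graph is a hypothetical dict mapping each object class to its feature rows and a boolean mask flagging real objects; fictitious objects are zero-filled rows with a False mask. Unlike Graph.pad, which the snippet above calls without using a return value, this pad returns a new graph, and toy_model is a hypothetical stand-in that, like EnerGNN models, outputs 0 on fictitious objects.

```python
# Hypothetical representation: class name -> {"features": rows, "mask": flags}


def true_shape(graph: dict) -> dict:
    """Number of non-fictitious objects of each class."""
    return {name: sum(h["mask"]) for name, h in graph.items()}


def max_shape(shapes: list) -> dict:
    """Largest number of objects of each class over a list of shapes."""
    names = set().union(*shapes)
    return {name: max(s.get(name, 0) for s in shapes) for name in names}


def pad(graph: dict, shape: dict, n_features: dict) -> dict:
    """Copy of `graph` with zero-filled fictitious objects (mask=False) appended."""
    out = {}
    for name, h in graph.items():
        missing = shape[name] - len(h["mask"])
        out[name] = {
            "features": h["features"] + [[0.0] * n_features[name] for _ in range(missing)],
            "mask": h["mask"] + [False] * missing,
        }
    return out


def collate(graphs: list) -> dict:
    """Stack same-shaped graphs along a new leading batch dimension."""
    return {name: {key: [g[name][key] for g in graphs] for key in ("features", "mask")}
            for name in graphs[0]}


def separate(batch: dict) -> list:
    """Inverse of collate: split a batch back into a list of graphs."""
    size = len(next(iter(batch.values()))["mask"])
    return [{name: {key: h[key][i] for key in ("features", "mask")} for name, h in batch.items()}
            for i in range(size)]


def unpad(graph: dict) -> dict:
    """Drop fictitious objects, keeping only rows whose mask is True."""
    return {name: {"features": [r for r, m in zip(h["features"], h["mask"]) if m],
                   "mask": [m for m in h["mask"] if m]}
            for name, h in graph.items()}


def toy_model(batch: dict) -> dict:
    """Hypothetical model: 2 * v + 1 on real objects, 0 on fictitious ones."""
    return {name: {"features": [[[2.0 * v + 1.0 if m else 0.0 for v in row]
                                 for row, m in zip(rows, mask)]
                                for rows, mask in zip(h["features"], h["mask"])],
                   "mask": h["mask"]}
            for name, h in batch.items()}
```

With a first context holding 5 switches and a second one holding 7, max_shape returns {"switches": 7}, the first context receives 2 fictitious switches, and unpadding the separated decisions gives back exactly 5 and 7 objects.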

The following ProblemBatch implementation assumes that:

  • Single Problem instances have been generated and saved beforehand,

  • The gradient computation can be performed in batch,

  • A max_shape has been computed over the whole dataset.

from typing import Any
from energnn.problem import ProblemBatch
from energnn.graph import JaxGraph, GraphShape, GraphStructure, collate_graphs

class MyBatch(ProblemBatch):
    def __init__(self, path_list: list[str], max_shape: GraphShape):

        self.problems: list[MyProblem] = [MyProblem.load(path) for path in path_list]

        # Get all contexts, pad them and collate them together.
        # zip(*...) unpacks the (context, info) pairs into two tuples.
        context_list, _ = zip(*[pb.get_context() for pb in self.problems])
        np_context_list = [context.to_numpy_graph() for context in context_list]
        for np_context in np_context_list:
            np_context.pad(max_shape)
        np_context_batch = collate_graphs(np_context_list)
        self.context_batch = JaxGraph.from_numpy_graph(np_context_batch)

    @property
    def context_structure(self) -> GraphStructure:
        return CONTEXT_STRUCTURE

    @property
    def decision_structure(self) -> GraphStructure:
        return DECISION_STRUCTURE

    def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        return self.context_batch, {}

    def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        # _compute_batch_grad is a placeholder for your own batched gradient computation.
        batch_grad = self._compute_batch_grad(decision, self.problems)
        return batch_grad, {}

    def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[list[float], dict[str, Any]]:
        # _compute_score_list is a placeholder returning one score per problem instance.
        score_list = self._compute_score_list(decision, self.problems)
        return score_list, {}

Step 3 — Data Loading (ProblemLoader)

The energnn.problem.ProblemLoader is an iterator that yields batches.

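import random
from typing import Iterator
from energnn.graph import GraphShape
from energnn.problem import ProblemLoader, ProblemBatch

class MyLoader(ProblemLoader):
    def __init__(self, dataset: list[str], batch_size: int, max_shape: GraphShape, shuffle: bool = False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.max_shape = max_shape  # Passed on to every MyBatch
        self.shuffle = shuffle
        self._order = list(dataset)
        self._current_idx = 0

    def __iter__(self) -> Iterator[ProblemBatch]:
        # Reshuffle at the start of every epoch if requested.
        self._order = list(self.dataset)
        if self.shuffle:
            random.shuffle(self._order)
        self._current_idx = 0
        return self

    def __next__(self) -> ProblemBatch:
        if self._current_idx >= len(self._order):
            raise StopIteration

        # Slice dataset and return a MyBatch instance (the last one may be smaller).
        path_list = self._order[self._current_idx:self._current_idx + self.batch_size]
        self._current_idx += self.batch_size
        return MyBatch(path_list, self.max_shape)

    def __len__(self) -> int:
        # Ceiling division, so that the last incomplete batch is counted.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size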
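The batching logic itself does not depend on EnerGNN and can be tested on its own. The following plain-Python sketch (IndexBatchLoader is a hypothetical name) keeps the last incomplete batch, reshuffles at every epoch with a seedable random.Random, and reports a consistent length. Writing __iter__ as a generator rather than implementing __next__ means each call to iter() starts a fresh, independent pass; dropping the incomplete last batch would be an equally valid choice, provided __len__ then uses floor division.

```python
import random
from typing import Any, Iterator, Optional, Sequence


class IndexBatchLoader:
    """Hypothetical loader yielding lists of dataset items, one list per batch."""

    def __init__(self, dataset: Sequence[Any], batch_size: int, shuffle: bool = False, seed: Optional[int] = None):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.dataset = list(dataset)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = random.Random(seed)

    def __iter__(self) -> Iterator[list[Any]]:
        order = list(self.dataset)
        if self.shuffle:
            self._rng.shuffle(order)  # a new order at every epoch
        for start in range(0, len(order), self.batch_size):
            # The last batch may hold fewer than batch_size items.
            yield order[start:start + self.batch_size]

    def __len__(self) -> int:
        # Ceiling division: counts the last, possibly incomplete batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
```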

Interface Checklist

When implementing your custom use case, ensure these requirements are met:

  • context_structure and decision_structure properties are defined.

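  • get_context() returns an energnn.graph.JaxGraph.

  • get_gradient() returns an energnn.graph.JaxGraph with the same topology as the decision.

  • get_score() returns a scalar (for Problem) or a list of scalars (for ProblemBatch).

  • Each of these methods returns its main output together with an info dictionary (possibly empty), as in the examples above.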

  • Graphs are correctly converted between Graph (NumPy-based, useful for building/collating) and JaxGraph (JAX-based, used by the models).
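A quick smoke test can catch a forgotten member before the first training run. The helpers below are a hypothetical sketch, not part of EnerGNN: they only check names and the (value, info) return convention by duck typing, and do not verify that graphs have the right structure or type.

```python
from typing import Any

# Members listed in the checklist above (names only; types are not verified here).
REQUIRED_PROPERTIES = ("context_structure", "decision_structure")
REQUIRED_METHODS = ("get_context", "get_gradient", "get_score")


def missing_members(obj: Any) -> list[str]:
    """Names from the checklist that `obj` does not provide."""
    missing = [name for name in REQUIRED_PROPERTIES if not hasattr(obj, name)]
    missing += [name for name in REQUIRED_METHODS if not callable(getattr(obj, name, None))]
    return missing


def returns_value_and_info(result: Any) -> bool:
    """True if `result` follows the (value, info_dict) convention used in the examples."""
    return isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict)
```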

Summary

By implementing these interfaces, your problem becomes fully compatible with EnerGNN’s models and trainers. You can find more practical examples in the Tutorial or by looking at the tests/utils.py file in the repository.

Next steps

  • See Basics for more details on H2MG graphs.

  • Visit the API reference for the full API specification of the problem module.