Use Case Implementation
This page explains how EnerGNN can be leveraged to address your own custom use cases. But first, make sure to check the Basics and Tutorial pages.
Each use case implementation shall encompass its own business logic:
- How input data (i.e., contexts) are defined and loaded in memory,
- What output data (i.e., decisions) should look like,
- How the gradient shall be estimated (closed-form for supervised learning vs. Monte Carlo estimation for more complex cases).
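To make the distinction concrete, here is a minimal sketch on plain NumPy arrays rather than JaxGraph objects. It assumes a toy supervised cost \(f(y) = \tfrac{1}{2}\|y - y^*\|^2\), whose closed-form gradient is \(y - y^*\), and compares it with a Monte Carlo estimate obtained by Gaussian smoothing, one common zeroth-order technique that only needs evaluations of \(f\). Both function names are hypothetical and not part of EnerGNN.

```python
from typing import Callable

import numpy as np


def closed_form_gradient(y: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Exact gradient of f(y) = 0.5 * ||y - target||^2 (supervised setting)."""
    return y - target


def monte_carlo_gradient(f: Callable[[np.ndarray], float], y: np.ndarray,
                         sigma: float = 0.1, n_samples: int = 20_000,
                         seed: int = 0) -> np.ndarray:
    """Gaussian-smoothing estimate of grad f(y), using only evaluations of f.

    Averages (f(y + sigma * eps) - f(y)) * eps / sigma over eps ~ N(0, I).
    Subtracting f(y) acts as a baseline: it reduces variance without adding bias.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples, y.size))
    f_y = f(y)
    # Evaluate the black-box objective at every perturbed decision.
    deltas = np.array([f(y + sigma * e) - f_y for e in eps])
    return (deltas[:, None] * eps).mean(axis=0) / sigma


target = np.array([1.0, -2.0, 0.5])
y = np.zeros(3)
f = lambda v: 0.5 * np.sum((v - target) ** 2)
exact = closed_form_gradient(y, target)      # y - target = [-1, 2, -0.5]
estimate = monte_carlo_gradient(f, y)        # close to `exact`, up to sampling noise
```

The Monte Carlo estimate only matches the closed form up to sampling noise, which is why such estimators are usually kept for objectives that cannot be differentiated directly.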
Depending on the use case, very different implementation choices can be made, but all should respect the interface
defined in energnn.problem.
The following guide walks you through the major steps in implementing your own EnerGNN use case.
Overview
Implementing your custom use case requires the following class implementations.
- Problem – Implements the logic for a single problem instance (context, gradient, score).
- ProblemBatch – Handles batching of multiple problem instances together. The implementation can be optimized for parallel computation, and even leverage GPU parallelization.
- ProblemLoader – Iterates over a whole dataset of problems, returning a different ProblemBatch at every iteration.
All three classes must, however, share two common properties, context_structure
and decision_structure, which define the names of the object classes, and of their
respective ports and features, appearing in contexts and decisions.
Step 0 — Define Graph Structures
EnerGNN uses the class GraphStructure to understand the format of your data.
You must define one structure for your contexts and one for your decisions.
Both are mandatory properties of your Problem, ProblemBatch
and ProblemLoader implementations.
```python
from energnn.graph import HyperEdgeSetStructure, GraphStructure

# Example: a context graph with lines, switches, generators and loads
CONTEXT_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={
    "lines": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=["r", "x"]),
    "switches": HyperEdgeSetStructure.from_list(port_list=["bus1", "bus2"], feature_list=None),
    "generators": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p0", "q0"]),
    "loads": HyperEdgeSetStructure.from_list(port_list=["bus"], feature_list=["p", "q"]),
})

# Let us say that we wish to predict a log-probability for each switch to be open,
# along with a generation variation.
DECISION_STRUCTURE: GraphStructure = GraphStructure.from_dict(hyper_edge_set_structure_dict={
    "switches": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["log_prob"]),
    "generators": HyperEdgeSetStructure.from_list(port_list=None, feature_list=["delta_p"]),
})
```
Important constraints:
- All classes in the context shall be at least of order 1 (i.e., have 1 or more ports);
- All classes appearing in the decision shall also appear in the context;
- No port can be predicted by the GNN, so all port_list attributes in the decision structure shall be None;
- For now, there is no support for global features (i.e., features that are not borne by a specific object), but feel free to reach out if that’s something you would like to see included.
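The first three constraints can be checked mechanically before any training run. The sketch below is a hypothetical helper working on plain dictionaries that mirror the example structures above (class name mapped to its port list and feature list); it does not use EnerGNN's GraphStructure API.

```python
from typing import Optional

# Hypothetical plain-dict mirror of a GraphStructure:
# class name -> {"ports": port names or None, "features": feature names or None}
StructureDict = dict[str, dict[str, Optional[list[str]]]]


def check_structures(context: StructureDict, decision: StructureDict) -> list[str]:
    """Return human-readable violations (an empty list means the structures are valid)."""
    errors = []
    for name, cls in context.items():
        # Constraint 1: every context class must have at least one port (order >= 1).
        if not cls.get("ports"):
            errors.append(f"context class '{name}' has no port")
    for name, cls in decision.items():
        # Constraint 2: decision classes must also appear in the context.
        if name not in context:
            errors.append(f"decision class '{name}' is missing from the context")
        # Constraint 3: ports cannot be predicted, so decision port lists must be None.
        if cls.get("ports") is not None:
            errors.append(f"decision class '{name}' declares ports")
    return errors


context = {
    "lines": {"ports": ["bus1", "bus2"], "features": ["r", "x"]},
    "switches": {"ports": ["bus1", "bus2"], "features": None},
    "generators": {"ports": ["bus"], "features": ["p0", "q0"]},
    "loads": {"ports": ["bus"], "features": ["p", "q"]},
}
decision = {
    "switches": {"ports": None, "features": ["log_prob"]},
    "generators": {"ports": None, "features": ["delta_p"]},
}
print(check_structures(context, decision))  # []
```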
Step 1 — Implement the Problem Interface
Your implementation of the class Problem should represent a single problem instance.
You must implement the following properties: context_structure and decision_structure (see Step 0).
And the following methods:
- get_context(): Returns the context graph \(x\), instantiated as a JaxGraph,
- get_gradient(): Computes \(\nabla_y f(y;x)\) for a given decision \(y\), instantiated as a JaxGraph,
- get_score(): Computes \(f(y;x)\) for a given decision \(y\), as a float.
Tracking relevant quantities:
All three methods accept a keyword argument get_info that triggers an optional behavior.
Each method returns its main output together with a dictionary; when get_info is True, this dictionary holds the quantities that are passed to your experiment tracker.
This is useful for debugging and tracking, but not necessary in your first implementation.
Data representation:
Contexts, decisions and gradients are all instantiated as JaxGraph, which is a version of
Graph designed to work seamlessly with JAX.
Decoupling gradients and score:
You can use different objective functions in get_score()
and in get_gradient().
For instance, you can use a non-differentiable function \(f\) as a score, and a differentiable function \(f'\)
in get_gradient().
Loosely speaking, get_score() just has to return a scalar value that quantifies how good
a decision is, and get_gradient() just has to return the opposite of a direction of
improvement for a decision.
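As a hypothetical illustration on plain NumPy arrays (standing in for a generator feature of a JaxGraph), the score below counts generators whose output exceeds an assumed limit p_max. It is piecewise constant, hence useless as a gradient signal, so the gradient instead comes from a smooth squared-hinge surrogate \(f'(p) = \tfrac{1}{2}\sum_i \max(p_i - p_{\max}, 0)^2\) of the same violations. Lower scores are better in this toy example.

```python
import numpy as np


def score(p: np.ndarray, p_max: float) -> float:
    """Non-differentiable score: number of generators above their limit."""
    return float(np.sum(p > p_max))


def surrogate_gradient(p: np.ndarray, p_max: float) -> np.ndarray:
    """Gradient of the smooth surrogate 0.5 * sum(max(p - p_max, 0) ** 2)."""
    return np.maximum(p - p_max, 0.0)


p = np.array([0.5, 1.3, 2.0])
p_max = 1.0
print(score(p, p_max))  # 2.0

# One gradient step on the surrogate (step size 1.0) pulls both overloaded
# generators back to the limit, which the score then registers.
p_new = p - 1.0 * surrogate_gradient(p, p_max)
print(score(p_new, p_max))  # 0.0
```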
The skeleton below shows the expected structure of a Problem subclass; the private helpers are placeholders for your own logic.
```python
from typing import Any

from energnn.graph import JaxGraph, GraphStructure
from energnn.problem import Problem


class MyProblem(Problem):

    def __init__(self, path: Any):
        # Implement your own data import, and store relevant state data.
        # _import_from_pypowsybl is a placeholder for your own loading helper.
        self.context, self.state = self._import_from_pypowsybl(path)

    @property
    def context_structure(self) -> GraphStructure:
        return CONTEXT_STRUCTURE

    @property
    def decision_structure(self) -> GraphStructure:
        return DECISION_STRUCTURE

    def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        return self.context, {}

    def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        # Implement your own gradient estimation method.
        grad: JaxGraph = self._estimate_gradient(decision, self.state)
        return grad, {}

    def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[float, dict[str, Any]]:
        # Implement your own score estimation method.
        score: float = self._estimate_score(decision, self.state)
        return score, {}

    def save(self, path):
        # Implement your own save method.
        pass

    @classmethod
    def load(cls, path):
        # Implement your own load method; it must return a MyProblem instance.
        pass
```
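Because the class above depends on your own helpers, here is a self-contained toy counterpart that runs without EnerGNN. It is a hypothetical sketch: it does not subclass energnn.problem.Problem and uses a dictionary of NumPy arrays in place of a JaxGraph. It assumes a supervised cost \(f(y;x) = \tfrac{1}{2}\|y - y^*\|^2\) on the generators' delta_p feature, so the gradient has a closed form, and it shows how get_info can fill the returned dictionary.

```python
from typing import Any

import numpy as np


class ToyProblem:
    """Hypothetical supervised problem on plain arrays (no EnerGNN dependency)."""

    def __init__(self, context: dict[str, np.ndarray], target_delta_p: np.ndarray):
        self.context = context          # stands in for the JaxGraph context x
        self.target = target_delta_p    # supervised target y* for each generator

    def get_context(self, get_info: bool = False) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
        info = {"n_generators": int(self.target.shape[0])} if get_info else {}
        return self.context, info

    def get_gradient(self, *, decision: dict[str, np.ndarray],
                     get_info: bool = False) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
        # Closed-form gradient of 0.5 * ||y - y*||^2 with respect to y.
        residual = decision["delta_p"] - self.target
        info = {"grad_norm": float(np.linalg.norm(residual))} if get_info else {}
        # The gradient has the same layout as the decision.
        return {"delta_p": residual}, info

    def get_score(self, *, decision: dict[str, np.ndarray],
                  get_info: bool = False) -> tuple[float, dict[str, Any]]:
        residual = decision["delta_p"] - self.target
        score = float(0.5 * np.sum(residual ** 2))
        info = {"max_abs_error": float(np.max(np.abs(residual)))} if get_info else {}
        return score, info


problem = ToyProblem(context={"p0": np.array([1.0, 2.0])}, target_delta_p=np.array([0.1, -0.2]))
decision = {"delta_p": np.zeros(2)}
grad, _ = problem.get_gradient(decision=decision)
score, info = problem.get_score(decision=decision, get_info=True)
```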
Step 2 — Handle Batching (ProblemBatch)
To train efficiently on GPUs, multiple problems are grouped together into an energnn.problem.ProblemBatch.
The batch interface mirrors the Problem interface but operates on concatenated graphs.
It is very common that the different problem instances within a batch have a different number of objects for a given class. For instance, consider a batch with 2 instances, where:
- the first context has 5 switches,
- the second context has 7 switches.
To collate them together, we have to pad the first context with 2 fictitious switches, so that the two contexts end up having the same number of switches. The following code snippet shows how to do so.
```python
from energnn.graph import Graph, GraphShape, JaxGraph, collate_graphs, max_shape, separate_graphs

context_1: Graph = ...
context_2: Graph = ...

# Step 1: Get the two graph shapes.
# The true_shape property counts the non-fictitious objects of each class.
shape_1: GraphShape = context_1.true_shape
shape_2: GraphShape = context_2.true_shape

# Step 2: Compute the largest shape.
# For each object class, it takes the maximum number of objects over the list.
# (A new name is used so as not to shadow the max_shape function.)
batch_shape: GraphShape = max_shape([shape_1, shape_2])

# Step 3: Pad all contexts with fictitious objects if required.
context_1.pad(batch_shape)
context_2.pad(batch_shape)
# The true_shape property is not altered by the padding, but the current_shape is.

# Step 4: Collate contexts together.
context_batch = collate_graphs([context_1, context_2])

# Et voilà, you have a context batch filled with fictitious objects if necessary.
# We can pass it to a model to compute a batch of decisions.
# The EnerGNN models are implemented to return 0 values on fictitious objects,
# and to keep track of which objects actually are fictitious or not.
# Models work on JaxGraph, so the NumPy-based batch is converted first
# (my_model stands for your model instance).
decision_batch = my_model.forward_batch(JaxGraph.from_numpy_graph(context_batch))

# Wait... What if we need to split this batch of decisions,
# and get rid of fictitious objects?

# Step 5: Split decisions apart (back on the NumPy-based Graph).
decision_1, decision_2 = separate_graphs(decision_batch.to_numpy_graph())

# Step 6: Unpad decisions.
decision_1.unpad()
decision_2.unpad()
# And now we have decisions without any fictitious object!
```
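To see what these steps do to the underlying arrays, here is a hypothetical NumPy-only sketch for a single class (switches) with one feature. It does not reproduce how energnn.graph implements pad, collate_graphs or separate_graphs; it only assumes, as the comments above describe, that fictitious objects are tracked (here with a boolean mask) and receive zero values.

```python
import numpy as np


def pad(features: np.ndarray, n_max: int) -> tuple[np.ndarray, np.ndarray]:
    """Pad a (n_objects, n_features) array with zero rows; also return a real-object mask."""
    n = features.shape[0]
    padded = np.zeros((n_max, features.shape[1]), dtype=features.dtype)
    padded[:n] = features
    mask = np.arange(n_max) < n  # True for real objects, False for fictitious ones
    return padded, mask


# Two contexts with 5 and 7 switches, one feature each.
switches_1 = np.arange(5, dtype=float).reshape(5, 1)
switches_2 = np.arange(7, dtype=float).reshape(7, 1)
n_max = max(switches_1.shape[0], switches_2.shape[0])  # the "max shape" for this class

padded_1, mask_1 = pad(switches_1, n_max)
padded_2, mask_2 = pad(switches_2, n_max)
batch = np.stack([padded_1, padded_2])  # "collate": shape (2, 7, 1)
mask = np.stack([mask_1, mask_2])       # shape (2, 7)

# A stand-in "model" that follows the convention of returning 0 on fictitious objects.
decision_batch = np.where(mask[..., None], 2.0 * batch, 0.0)

# "Separate" along the batch axis, then "unpad" by keeping real objects only.
decision_1 = decision_batch[0][mask[0]]
decision_2 = decision_batch[1][mask[1]]
print(decision_1.shape, decision_2.shape)  # (5, 1) (7, 1)
```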
The following ProblemBatch implementation assumes that:
- Single Problem instances have been generated and saved beforehand,
- The gradient computation can be performed in batch,
- A max_shape has been computed over the whole dataset.
```python
from typing import Any

from energnn.graph import GraphShape, GraphStructure, JaxGraph, collate_graphs
from energnn.problem import ProblemBatch


class MyBatch(ProblemBatch):

    def __init__(self, path_list: list[str], max_shape: GraphShape):
        self.problems: list[MyProblem] = [MyProblem.load(path) for path in path_list]

        # Get all contexts, pad them and collate them together.
        context_list, _ = zip(*[pb.get_context() for pb in self.problems])
        np_context_list = [context.to_numpy_graph() for context in context_list]
        for np_context in np_context_list:
            np_context.pad(max_shape)  # pad to the dataset-wide shape
        np_context_batch = collate_graphs(np_context_list)
        self.context_batch = JaxGraph.from_numpy_graph(np_context_batch)

    @property
    def context_structure(self) -> GraphStructure:
        return CONTEXT_STRUCTURE

    @property
    def decision_structure(self) -> GraphStructure:
        return DECISION_STRUCTURE

    def get_context(self, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        return self.context_batch, {}

    def get_gradient(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict[str, Any]]:
        # _compute_batch_grad is your own (ideally vectorized) gradient helper.
        batch_grad = self._compute_batch_grad(decision, self.problems)
        return batch_grad, {}

    def get_score(self, *, decision: JaxGraph, get_info: bool = False) -> tuple[list[float], dict[str, Any]]:
        # _compute_score_list is your own helper returning one score per instance.
        score_list = self._compute_score_list(decision, self.problems)
        return score_list, {}
```
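The helpers _compute_batch_grad and _compute_score_list are left to you. One hypothetical way to write them for a supervised cost \(\tfrac{1}{2}\|y - y^*\|^2\) is sketched below in NumPy: the gradient of the whole padded batch is computed in one vectorized operation, and a real-object mask ensures that fictitious objects get a zero gradient and do not contribute to the per-instance scores. The (batch, objects, features) layout and the mask are assumptions of this sketch.

```python
import numpy as np


def compute_batch_grad(decision: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Gradient of 0.5 * ||y - y*||^2 for every instance; zero on fictitious objects.

    decision, target: (batch, n_objects, n_features); mask: (batch, n_objects), True = real.
    """
    return (decision - target) * mask[..., None]


def compute_score_list(decision: np.ndarray, target: np.ndarray, mask: np.ndarray) -> list[float]:
    """One score per instance, ignoring fictitious objects."""
    squared = ((decision - target) ** 2) * mask[..., None]
    return [float(s) for s in 0.5 * squared.sum(axis=(1, 2))]


# Batch of 2 instances padded to 3 objects; the second instance has only 2 real objects.
mask = np.array([[True, True, True], [True, True, False]])
target = np.ones((2, 3, 1))
decision = np.zeros((2, 3, 1))
print(compute_score_list(decision, target, mask))  # [1.5, 1.0]
```

The target values stored on fictitious objects do not matter here, since the mask cancels them.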
Step 3 — Data Loading (ProblemLoader)
The energnn.problem.ProblemLoader is an iterator that yields batches.
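```python
import math
import random
from typing import Iterator

from energnn.graph import GraphShape
from energnn.problem import ProblemLoader, ProblemBatch


class MyLoader(ProblemLoader):

    def __init__(self, dataset: list[str], batch_size: int, max_shape: GraphShape, shuffle: bool = False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.max_shape = max_shape  # forwarded to every MyBatch
        self.shuffle = shuffle
        self._order = list(dataset)
        self._current_idx = 0

    def __iter__(self) -> Iterator[ProblemBatch]:
        # Reshuffle a copy of the dataset at the start of each epoch if needed.
        self._order = list(self.dataset)
        if self.shuffle:
            random.shuffle(self._order)
        self._current_idx = 0
        return self

    def __next__(self) -> ProblemBatch:
        if self._current_idx >= len(self._order):
            raise StopIteration
        # Slice the dataset and return a MyBatch instance (the last batch may be smaller).
        path_list = self._order[self._current_idx:self._current_idx + self.batch_size]
        self._current_idx += self.batch_size
        return MyBatch(path_list, self.max_shape)

    def __len__(self) -> int:
        # Round up, so that the last (possibly smaller) batch is counted.
        return math.ceil(len(self.dataset) / self.batch_size)
```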
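To show how the three classes fit together, here is a hypothetical, self-contained toy loop on NumPy arrays. It is not EnerGNN's trainer: it assumes a linear "model" decision = context @ w, a minimal loader with optional shuffling, and applies the chain rule to turn the gradient returned by the batch (with respect to the decision) into a gradient with respect to the model parameters.

```python
import math
import random

import numpy as np


class ToyBatch:
    """Hypothetical batch: contexts (batch, n_in) and supervised targets (batch, n_out)."""

    def __init__(self, contexts: np.ndarray, targets: np.ndarray):
        self.contexts, self.targets = contexts, targets

    def get_context(self, get_info: bool = False):
        return self.contexts, {}

    def get_gradient(self, *, decision: np.ndarray, get_info: bool = False):
        # Gradient of 0.5 * ||y - y*||^2 with respect to the decision y.
        return decision - self.targets, {}

    def get_score(self, *, decision: np.ndarray, get_info: bool = False):
        # One score per instance, as for a ProblemBatch.
        return [float(s) for s in 0.5 * np.sum((decision - self.targets) ** 2, axis=1)], {}


class ToyLoader:
    """Hypothetical loader yielding ToyBatch objects; the last batch may be smaller."""

    def __init__(self, contexts: np.ndarray, targets: np.ndarray, batch_size: int,
                 shuffle: bool = False, seed: int = 0):
        self.contexts, self.targets = contexts, targets
        self.batch_size, self.shuffle = batch_size, shuffle
        self._rng = random.Random(seed)

    def __iter__(self):
        order = list(range(len(self.contexts)))
        if self.shuffle:
            self._rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            yield ToyBatch(self.contexts[idx], self.targets[idx])

    def __len__(self) -> int:
        return math.ceil(len(self.contexts) / self.batch_size)


# Realizable toy data: targets come from a fixed linear map w_true.
rng = np.random.default_rng(0)
x = rng.standard_normal((10, 3))
w_true = np.array([[1.0], [-2.0], [0.5]])
y_star = x @ w_true
loader = ToyLoader(x, y_star, batch_size=4, shuffle=True)

w = np.zeros((3, 1))  # parameters of the linear "model": decision = context @ w
for epoch in range(500):
    for batch in loader:
        context, _ = batch.get_context()
        decision = context @ w
        grad_y, _ = batch.get_gradient(decision=decision)
        # Chain rule: dL/dw = context^T @ dL/dy (what autodiff would do for a GNN).
        w -= 0.1 * context.T @ grad_y / len(context)
```

The batch never needs to know about the model: it only maps decisions to gradients, and the training side propagates them back to the parameters.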
Interface Checklist
When implementing your custom use case, ensure these requirements are met:
- context_structure and decision_structure properties are defined.
- get_context() returns an energnn.graph.JaxGraph.
- get_gradient() returns an energnn.graph.JaxGraph with the same topology as the decision.
- get_score() returns a scalar (for Problem) or a list of scalars (for ProblemBatch).
- Each of these three methods returns its output together with an info dictionary, as in the examples above.
- Graphs are correctly converted between Graph (NumPy-based, useful for building/collating) and JaxGraph (JAX-based, used by the models).
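Part of this checklist can be automated in your tests. The sketch below is a hypothetical helper for problems whose graphs are represented as dictionaries of NumPy arrays (not JaxGraph); it checks that get_gradient() mirrors the layout of the decision and that get_score() returns a scalar for a single problem. The QuadraticProblem class is a toy stand-in.

```python
from typing import Any

import numpy as np


def check_problem(problem: Any, decision: dict[str, np.ndarray]) -> None:
    """Hypothetical sanity checks mirroring the checklist, for dict-of-array problems."""
    _, info = problem.get_context()
    assert isinstance(info, dict), "get_context() must also return an info dictionary"
    grad, _ = problem.get_gradient(decision=decision)
    # Same "topology": same keys and same array shapes as the decision.
    assert grad.keys() == decision.keys(), "gradient and decision keys differ"
    for key, value in decision.items():
        assert np.shape(grad[key]) == np.shape(value), f"shape mismatch for '{key}'"
    score, _ = problem.get_score(decision=decision)
    assert np.isscalar(score), "get_score() must return a scalar for a single problem"


class QuadraticProblem:
    """Toy problem with cost 0.5 * ||delta_p||^2."""

    def get_context(self, get_info: bool = False):
        return {"p0": np.array([1.0, 2.0])}, {}

    def get_gradient(self, *, decision, get_info: bool = False):
        return {"delta_p": decision["delta_p"]}, {}

    def get_score(self, *, decision, get_info: bool = False):
        return float(0.5 * np.sum(decision["delta_p"] ** 2)), {}


check_problem(QuadraticProblem(), {"delta_p": np.array([0.3, -0.1])})  # passes silently
```

Since the checks rely on assert statements, they are skipped when Python runs with the -O flag; raise explicit exceptions instead if you need them in production code.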
Summary
By implementing these interfaces, your problem becomes fully compatible with EnerGNN’s models and trainers.
You can find more practical examples in the Tutorial or by looking at the tests/utils.py
file in the repository.
Next steps
See Basics for more details on H2MG graphs.
Visit the API reference for the full API specification of the problem module.