Tutorial¶
In this energnn tutorial, we will review:
How to install energnn,
The interaction with typical implementations of energnn.problem.Problem, energnn.problem.ProblemBatch and energnn.problem.ProblemLoader,
The creation of a GNN model using energnn.model.ready_to_use,
The training of the model with energnn.trainer.Trainer,
The usage of the trained model.
Installation¶
To install the latest stable release of energnn on CPU,
pip install energnn
For the GPU version,
pip install "energnn[gpu]"
Problem Class¶
Let’s consider the following use case. Knowing the pair \((A, b)\), we wish to find \(x\) such that \(Ax = b\). Let us generate a random problem instance and explore its interface.
[21]:
from energnn.problem.example import LinearSystemProblemGenerator
pb_generator = LinearSystemProblemGenerator(seed=7, n_max=4)
problem = pb_generator.generate_problem()
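To make the use case concrete outside of energnn, here is a minimal NumPy sketch of what such an instance boils down to. It is not the code of LinearSystemProblemGenerator: the sampling scheme, the diagonal shift and all variable names are assumptions made for illustration.

```python
import numpy as np

# Illustrative stand-in for a problem instance; NOT energnn's generator.
rng = np.random.default_rng(7)
n = 4

# ASSUMPTION: a random matrix shifted along the diagonal,
# so that it is very likely to be well conditioned.
A = rng.normal(size=(n, n)) + n * np.eye(n)
b = rng.normal(size=n)

# Reference solution x* of A x = b.
x_star = np.linalg.solve(A, b)

# The residual should vanish up to floating point error.
residual = np.linalg.norm(A @ x_star - b)
print(f"residual = {residual:.2e}")
```

In the tutorial, the pair \((A, b)\) plays the role of the model input and \(x^\star\) is the target that the model should approximate.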
Context Graph¶
The input of our GNN model is referred to as the context, instantiated as an energnn.graph.Graph object.
In this case, it is the pair \((A, b)\), framed as a Hyper Heterogeneous Multi Graph (H2MG). The matrix \(A\) is represented by a hyper-edge set called arrow, and the vector \(b\) is represented by a hyper-edge set called source.
[22]:
# Let us explore the context structure
print(problem.context_structure)
Ports Features
Name
arrow [from, to] [value]
source [id] [value]
Each hyper-edge set has ports and / or features. Ports define the connectivity of the graph and associate a hyper-edge with an integer address.
[23]:
# Print the context graph associated to the problem instance
context, _ = problem.get_context()
print(context)
arrow
ports features
from to value
object_id
0 0.0 1.0 1.457083
1 0.0 3.0 3.200915
2 1.0 2.0 2.025772
3 2.0 1.0 1.216732
4 3.0 2.0 0.748781
source
ports features
id value
object_id
0 0.0 -0.155426
1 1.0 1.770309
2 2.0 -0.406632
3 3.0 0.654355
The context of this specific problem instance has:
5 arrow objects,
4 source objects.
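As an illustration of this encoding, the sketch below turns a small dense matrix and vector into two tables shaped like the arrow and source sets. It only uses NumPy and plain dictionaries rather than energnn.graph.Graph, the system is made up, and mapping from to the row index and to to the column index is an assumption.

```python
import numpy as np

# Made-up 3x3 system, for illustration only.
A = np.array([[0.0, 2.0, 0.0],
              [1.5, 0.0, 0.0],
              [0.0, 0.0, -0.7]])
b = np.array([1.0, -2.0, 0.5])

# One "arrow" object per non-zero coefficient of A.
# ASSUMPTION: the "from" port stores the row index, the "to" port the column index.
rows, cols = np.nonzero(A)
arrow = {"from": rows, "to": cols, "value": A[rows, cols]}

# One "source" object per entry of b, addressed by its index.
source = {"id": np.arange(len(b)), "value": b}

print(len(arrow["value"]), "arrow objects")
print(len(source["value"]), "source objects")
```

Only non-zero coefficients produce an arrow object, which is why the number of arrow objects can differ from one problem instance to the next.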
Decision Graph¶
The output of our GNN model is referred to as the decision, instantiated as an energnn.graph.Graph object.
In this case, it is the variable \(x\).
This specific problem class has a helper method called get_zero_decision (not part of the mandatory interface), that returns a decision of the right shape and structure, filled with zeros.
[24]:
# Let us explore the decision structure
print(problem.decision_structure)
Ports Features
Name
source None [value]
Notice that decisions concern only a subset of the classes available in the context, and that they have no ports, only features.
[25]:
# Print the zero decision graph associated with the problem instance
decision, _ = problem.get_zero_decision()
print(decision)
source
features
value
object_id
0 0.0
1 -0.0
2 0.0
3 0.0
Objective Function¶
The score of a given decision is instantiated as a float.
In this case, we use the Mean Squared Error \(\frac{1}{2} \Vert x - x^\star \Vert^2\), where \(x^\star\) is the solution.
We can evaluate it by injecting the zero decision we have just retrieved.
[26]:
score, _ = problem.get_score(decision=decision)
print(score)
0.2227775752544403
Gradient Graph¶
The gradient of the objective function is instantiated as an energnn.graph.Graph object.
In this case, it is the vector \(x-x^\star\). Notice that for more complex use cases, the gradient can have more complex expressions, or even require Monte-Carlo simulations.
We can evaluate it by injecting the zero decision.
[27]:
gradient, _ = problem.get_gradient(decision=decision)
print(gradient)
source
features
value
object_id
0 -0.070729
1 0.334200
2 -0.873894
3 -0.103574
Notice that this gradient is the exact same type of object as the decision.
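The printed values allow a small consistency check with plain NumPy. At the zero decision the gradient \(x - x^\star\) equals \(-x^\star\), so \(x^\star\) can be read off the output above and plugged into \(Ax = b\). The sketch rebuilds a dense \(A\) from the arrow table, assuming that from holds the row index and to the column index. This reading is an assumption, and the check below tests it against the rounded printed values.

```python
import numpy as np

# Values copied from the printed context (rounded to 6 decimals).
arrow_from = np.array([0, 0, 1, 2, 3])
arrow_to = np.array([1, 3, 2, 1, 2])
arrow_value = np.array([1.457083, 3.200915, 2.025772, 1.216732, 0.748781])
# Source ids are 0..3 in order, so the values directly form the vector b.
b = np.array([-0.155426, 1.770309, -0.406632, 0.654355])

# Gradient printed at the zero decision, i.e. 0 - x*.
gradient_at_zero = np.array([-0.070729, 0.334200, -0.873894, -0.103574])
x_star = -gradient_at_zero

# ASSUMPTION: "from" indexes the rows of A and "to" indexes its columns.
n = len(b)
A = np.zeros((n, n))
A[arrow_from, arrow_to] = arrow_value

# Should be close to zero, up to the rounding of the printed values.
max_error = np.max(np.abs(A @ x_star - b))
print(f"max |A x* - b| = {max_error:.1e}")
```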
Just as a quick sanity check, we can perform a gradient descent and make sure that the objective decreases.
[28]:
from energnn.graph import JaxGraph
alpha = 0.5
objective, _ = problem.get_score(decision=decision)
print(f"Step 0, objective = {objective}")
for i in range(10):
    gradient, _ = problem.get_gradient(decision=decision)
    # Update decision
    numpy_gradient = gradient.to_numpy_graph()  # For now, we need to convert the gradient to a numpy graph.
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)
    objective, _ = problem.get_score(decision=decision)
    print(f"Step {i + 1}, objective = {objective}")
Step 0, objective = 0.2227775752544403
Step 1, objective = 0.05569439381361008
Step 2, objective = 0.013923597522079945
Step 3, objective = 0.003480899380519986
Step 4, objective = 0.0008702246705070138
Step 5, objective = 0.00021755615307483822
Step 6, objective = 5.438903463073075e-05
Step 7, objective = 1.3597240467788652e-05
Step 8, objective = 3.399258957870188e-06
Step 9, objective = 8.49789046242222e-07
Step 10, objective = 2.1243486969524383e-07
The objective function successfully decreases! Now, let’s explore how multiple problems can be batched together.
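The rate of decrease observed in the sanity check above can be explained in one line. With a gradient \(x - x^\star\) and a step \(\alpha\), the update gives \(x_{k+1} - x^\star = (1 - \alpha)(x_k - x^\star)\), so any objective proportional to \(\Vert x - x^\star \Vert^2\) is multiplied by \((1 - \alpha)^2\) at each step, which is \(1/4\) for \(\alpha = 0.5\). This agrees with the printed log, where each value is about a quarter of the previous one. The NumPy sketch below reproduces the behaviour on a made-up \(x^\star\); it does not call energnn.

```python
import numpy as np

# Made-up target, standing in for the solution x* of a problem instance.
rng = np.random.default_rng(0)
x_star = rng.normal(size=4)


def objective(x):
    """Half squared distance to the target."""
    return 0.5 * np.sum((x - x_star) ** 2)


def gradient(x):
    """Gradient of the objective with respect to x."""
    return x - x_star


alpha = 0.5
x = np.zeros(4)  # Same starting point as the zero decision.
history = [objective(x)]
for _ in range(10):
    x = x - alpha * gradient(x)
    history.append(objective(x))

# Ratio between consecutive objectives, expected to be (1 - alpha) ** 2.
ratios = [after / before for before, after in zip(history[:-1], history[1:])]
print(ratios)
```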
Problem Batch¶
Interacting with a single problem instance is useful at inference time, or for debugging purposes. But to train a whole Graph Neural Network model, it is necessary to process batches of problem instances together.
[29]:
from energnn.problem.example import LinearSystemProblemGenerator
pb_generator = LinearSystemProblemGenerator(seed=9, n_max=3)
problem_batch = pb_generator.generate_problem_batch(batch_size=3)
# Let us explore the context and decision structures
print("Context Structure:\n", problem_batch.context_structure, "\n")
print("Decision Structure:\n", problem_batch.decision_structure)
Context Structure:
Ports Features
Name
arrow [from, to] [value]
source [id] [value]
Decision Structure:
Ports Features
Name
source None [value]
Here, contexts are still graphs, but this time with an extra dimension:
[30]:
context, _ = problem_batch.get_context()
print(context)
arrow
ports features
from to value
batch_id object_id
0 0 0.0 0.0 -1.116066
1 1.0 0.0 -0.481135
2 1.0 1.0 -1.517331
3 1.0 2.0 -0.490872
4 2.0 1.0 -0.647947
5 2.0 2.0 0.635891
6 0.0 0.0 0.000000
7 0.0 0.0 0.000000
8 0.0 0.0 0.000000
1 0 0.0 0.0 -0.857040
1 0.0 1.0 1.528224
2 0.0 2.0 0.904988
3 1.0 0.0 0.541645
4 1.0 1.0 0.701052
5 1.0 2.0 -0.054635
6 2.0 0.0 0.081804
7 2.0 1.0 -1.281731
8 2.0 2.0 0.158457
2 0 0.0 0.0 -0.745505
1 0.0 0.0 0.000000
2 0.0 0.0 0.000000
3 0.0 0.0 0.000000
4 0.0 0.0 0.000000
5 0.0 0.0 0.000000
6 0.0 0.0 0.000000
7 0.0 0.0 0.000000
8 0.0 0.0 0.000000
source
ports features
id value
batch_id object_id
0 0 0.0 -1.942086
1 1.0 -1.634691
2 2.0 0.257661
1 0 0.0 -1.926273
1 1.0 -0.110379
2 2.0 0.666537
2 0 0.0 -0.104837
1 0.0 0.000000
2 0.0 0.000000
Notice that the different contexts of the batch do not have the same connectivity, and do not have the same number of arrow and source objects. To batch the different contexts together, it is thus necessary to pad them with zeros.
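A minimal NumPy sketch of this zero-padding, using the source values printed above. The helper pad_to and the mask array are hypothetical names introduced for illustration; they are not part of the energnn interface, and the tutorial does not say whether the library tracks padding this way.

```python
import numpy as np


def pad_to(values, size):
    """Right-pad a 1-D sequence with zeros up to `size` entries (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    padded = np.zeros(size)
    padded[: len(values)] = values
    return padded


# "source" values of the three problem instances of the batch printed above.
source_values = [
    [-1.942086, -1.634691, 0.257661],
    [-1.926273, -0.110379, 0.666537],
    [-0.104837],
]

max_size = max(len(v) for v in source_values)

# Stack the padded instances along a new leading batch dimension.
batch = np.stack([pad_to(v, max_size) for v in source_values])

# ASSUMPTION: a 0/1 mask is one common way to remember which entries are real.
mask = np.stack([pad_to(np.ones(len(v)), max_size) for v in source_values])

print(batch)
print(mask)
```

Padding gives every instance of the batch the same array shape, at the cost of carrying entries that hold no information.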
Still, a ProblemBatch can be handled in a very similar way to a single Problem.
[31]:
alpha = 0.5
decision, _ = problem_batch.get_zero_decision()
objective, _ = problem_batch.get_score(decision=decision)
print(f"Step 0, objective = {objective}")
for i in range(10):
    gradient, _ = problem_batch.get_gradient(decision=decision)
    # Update decision
    numpy_gradient = gradient.to_numpy_graph()
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)
    objective, _ = problem_batch.get_score(decision=decision)
    print(f"Step {i + 1}, objective = {objective}")
Step 0, objective = [1.205530047416687, 0.3518087863922119, 0.006591798271983862]
Step 1, objective = [0.30138251185417175, 0.08795219659805298, 0.0016479495679959655]
Step 2, objective = [0.07534561306238174, 0.021988049149513245, 0.0004119873046875]
Step 3, objective = [0.018836403265595436, 0.005497012287378311, 0.000102996826171875]
Step 4, objective = [0.0047090970911085606, 0.0013742543524131179, 2.574920654296875e-05]
Step 5, objective = [0.0011772741563618183, 0.00034356300602667034, 6.4373016357421875e-06]
Step 6, objective = [0.00029431749135255814, 8.589089702581987e-05, 1.6093254089355469e-06]
Step 7, objective = [7.357935101026669e-05, 2.1472886146511883e-05, 4.023313522338867e-07]
Step 8, objective = [1.8394850485492498e-05, 5.3682501857110765e-06, 1.0058283805847168e-07]
Step 9, objective = [4.598733994498616e-06, 1.342021732853027e-06, 2.514570951461792e-08]
Step 10, objective = [1.149680656453711e-06, 3.3547598832228687e-07, 6.28642737865448e-09]
Notice that scores are now lists of floats, with one value per problem instance in the batch.
Problem Loader¶
Being able to process problem instances per batch is nice, but not enough. To train a Graph Neural Network, we’ll need to iterate over multiple minibatches of problem instances. That’s where the ProblemLoader class comes in.
[32]:
from energnn.problem.example import LinearSystemProblemLoader
problem_loader = LinearSystemProblemLoader(batch_size=4, seed=7, dataset_size=16, n_max=4)
# Let us explore the context and decision structures
print("Context Structure:\n", problem_loader.context_structure, "\n")
print("Decision Structure:\n", problem_loader.decision_structure)
Context Structure:
Ports Features
Name
arrow [from, to] [value]
source [id] [value]
Decision Structure:
Ports Features
Name
source None [value]
It allows us to iterate over batches of problems.
[33]:
for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()
    decision, _ = problem_batch.get_zero_decision()
    objective, _ = problem_batch.get_score(decision=decision)
    print("Objective:", objective)
Objective: [0.2227775752544403, 0.5266321897506714, 0.04437382146716118, 0.17635096609592438]
Objective: [0.3786073327064514, 0.6252117156982422, 0.338590145111084, 0.010212971828877926]
Objective: [1.4664279222488403, 0.7719758749008179, 1.2088079452514648, 0.7055290937423706]
Objective: [1.4254088401794434, 2.1467173099517822, 0.37498927116394043, 0.3698219358921051]
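The four printed lines correspond to the 16 / 4 = 4 batches of the dataset. The iteration pattern itself can be sketched in pure Python as follows. ToyLoader is a hypothetical stand-in written for this illustration; it is not the energnn.problem.ProblemLoader interface, and its "problems" are plain integers.

```python
class ToyLoader:
    """Hypothetical minimal loader that yields consecutive mini-batches."""

    def __init__(self, dataset, batch_size):
        self.dataset = list(dataset)
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a possibly incomplete last one.
        return -(-len(self.dataset) // self.batch_size)

    def __iter__(self):
        for start in range(0, len(self.dataset), self.batch_size):
            yield self.dataset[start:start + self.batch_size]


# Same sizes as in the cell above: 16 instances, batches of 4.
loader = ToyLoader(dataset=range(16), batch_size=4)
batches = list(loader)
for batch in batches:
    print("Batch:", batch)
```

Because the object defines `__iter__`, it can be passed to a `for` loop several times, once per epoch.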
Graph Neural Network Model¶
Let us instantiate a small Graph Neural Network model, that is adapted to the context and decision structure of our problem class.
[34]:
from energnn.model.ready_to_use import TinyRecurrentEquivariantGNN
model = TinyRecurrentEquivariantGNN(
    in_structure=problem_loader.context_structure,
    out_structure=problem_loader.decision_structure,
)
Make sure that your model is in evaluation mode first!
[35]:
model.eval() # Set the model in evaluation mode.
# model.train() # To set the model in train mode.
It is able to take as input a context and return a decision.
[36]:
problem = pb_generator.generate_problem()
context, _ = problem.get_context()
decision, _ = model(context)
print(decision)
source
features
value
object_id
0 0.237434
It can also process batches of contexts and return batches of decisions.
[37]:
problem_batch = pb_generator.generate_problem_batch(batch_size=4)
context, _ = problem_batch.get_context()
decision, _ = model.forward_batch(graph=context)
print(decision)
source
features
value
batch_id object_id
0 0 0.294713
1 0.000000
2 0.000000
1 0 0.000000
1 3.260542
2 0.000000
2 0 0.326413
1 0.000000
2 0.000000
3 0 -0.784929
1 -0.000000
2 -0.000000
Trainer¶
Let us train our Graph Neural Network model over a problem loader. The core training loop is defined by the following pseudocode.
for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()
    decision, _ = model.forward_batch(context)
    gradient, _ = problem_batch.get_gradient(decision)
    model.backprop(gradient)
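To show how a gradient expressed with respect to the decision can drive a parameter update, here is a self-contained NumPy sketch that mirrors the four steps of the pseudocode above. Everything in it is an assumption made for illustration: the "model" is a single dense matrix W rather than a GNN, the matrix A is fixed and made up, and the update is plain stochastic gradient descent. It says nothing about how energnn.trainer is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
n, batch_size, n_steps, learning_rate = 3, 4, 500, 0.05

# Fixed, made-up matrix; problem instances only differ by their right-hand side b.
A = rng.normal(size=(n, n)) + n * np.eye(n)
A_inv = np.linalg.inv(A)

# Toy "model": a dense matrix W mapping a context b to a decision x = W b.
W = np.zeros((n, n))
initial_error = np.linalg.norm(W - A_inv)

for _ in range(n_steps):
    # Counterpart of get_context(): sample a batch of right-hand sides.
    b = rng.normal(size=(batch_size, n))
    x_star = b @ A_inv.T
    # Counterpart of forward_batch(): one decision per row of the batch.
    decision = b @ W.T
    # Counterpart of get_gradient(): gradient with respect to the decision.
    gradient = decision - x_star
    # Counterpart of backprop(): chain rule through x = W b, averaged over the batch.
    grad_W = gradient.T @ b / batch_size
    W -= learning_rate * grad_W

final_error = np.linalg.norm(W - A_inv)
print(f"error on W: {initial_error:.3f} -> {final_error:.2e}")
```

The point of the sketch is the division of roles: the problem only provides a gradient with respect to the decision, and the model is responsible for propagating it back to its own parameters.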
In practice, we use energnn.trainer to implement the training logic. It relies on:
optax for the optimizer,
orbax for checkpointing and saving/loading models.
[38]:
from energnn.trainer import Trainer
import optax
trainer = Trainer(model=model, gradient_transformation=optax.adam(learning_rate=1e-3))
The training is performed by iterating over a train loader, and the validation score is periodically computed on a validation loader.
[39]:
train_loader = LinearSystemProblemLoader(seed=7, dataset_size=64, batch_size=4, n_max=3)
val_loader = LinearSystemProblemLoader(seed=8, dataset_size=8, batch_size=4, n_max=3)
[40]:
_ = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    eval_before_training=True,
    n_epochs=3,
)
Validation: 100%|██████████| 2/2 [00:03<00:00, 1.83s/batch, score=2.1002e+00]
Epoch 1/3: 100%|██████████| 16/16 [00:22<00:00, 1.42s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00, 1.26batch/s, score=1.0779e+00]
Epoch 2/3: 100%|██████████| 16/16 [00:19<00:00, 1.21s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00, 1.11batch/s, score=9.9698e-01]
Epoch 3/3: 100%|██████████| 16/16 [00:18<00:00, 1.13s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00, 1.21batch/s, score=9.3471e-01]