Tutorial

In this energnn tutorial, we will review:

  • How to install energnn.

  • The interaction with typical implementations of energnn.problem.Problem, energnn.problem.ProblemBatch and energnn.problem.ProblemLoader.

  • The creation of a GNN model using energnn.model.ready_to_use.

  • The training of the model with energnn.trainer.Trainer.

  • The usage of the trained model.

Installation

To install the latest stable release of energnn for CPU:

pip install energnn

For the GPU version, install the gpu extra:

pip install "energnn[gpu]"

Problem Class

Let’s consider the following use case. Given the pair \((A, b)\), we wish to find \(x\) such that \(Ax = b\). Let us generate a random problem instance and explore its interface.

[21]:
from energnn.problem.example import LinearSystemProblemGenerator

pb_generator = LinearSystemProblemGenerator(seed=7, n_max=4)
problem = pb_generator.generate_problem()

Context Graph

The input of our GNN model is referred to as the context, instantiated as an energnn.graph.Graph object.

In this case, it is the pair \((A, b)\), framed as a Hyper Heterogeneous Multi Graph (H2MG). The matrix \(A\) is represented by a hyper-edge set called arrow, and the vector \(b\) is represented by a hyper-edge set called source.

[22]:
# Let us explore the context structure
print(problem.context_structure)
             Ports Features
Name
arrow   [from, to]  [value]
source        [id]  [value]

Each hyper-edge set has ports and/or features. Ports define the connectivity of the graph: each port associates a hyper-edge with an integer address, so that hyper-edges sharing an address are connected.

[23]:
# Print the context graph associated with the problem instance
context, _ = problem.get_context()
print(context)
arrow
          ports       features
           from   to     value
object_id
0           0.0  1.0  1.457083
1           0.0  3.0  3.200915
2           1.0  2.0  2.025772
3           2.0  1.0  1.216732
4           3.0  2.0  0.748781
source
          ports  features
             id     value
object_id
0           0.0 -0.155426
1           1.0  1.770309
2           2.0 -0.406632
3           3.0  0.654355

The context of this specific problem instance has:

  • 5 arrow objects,

  • 4 source objects.

Decision Graph

The output of our GNN model is referred to as the decision, instantiated as an energnn.graph.Graph object.

In this case, it is the variable \(x\).

This specific problem class has a helper method called get_zero_decision (not part of the mandatory interface), which returns a decision of the right shape and structure, filled with zeros.

[24]:
# Let us explore the decision structure
print(problem.decision_structure)
       Ports Features
Name
source  None  [value]

Notice that the decision only covers a subset of the hyper-edge classes available in the context (here, source), and that it has no ports, only features.

[25]:
# Print a zero-filled decision for the problem instance
decision, _ = problem.get_zero_decision()
print(decision)
source
          features
             value
object_id
0              0.0
1             -0.0
2              0.0
3              0.0

Objective Function

The score of a given decision is instantiated as a float.

In this case, we use the squared error \(\frac{1}{2} \Vert x - x^\star \Vert^2\), where \(x^\star\) is the solution of \(Ax = b\).

We can evaluate it by injecting the zero decision we have just retrieved.

[26]:
score, _ = problem.get_score(decision=decision)
print(score)
0.2227775752544403

Gradient Graph

The gradient of the objective function is instantiated as an energnn.graph.Graph object.

In this case, it is the vector \(x-x^\star\). Notice that for more complex use cases, the gradient can have more complex expressions, or even require Monte-Carlo simulations.

We can evaluate it by injecting the zero decision.

[27]:
gradient, _ = problem.get_gradient(decision=decision)
print(gradient)
source
           features
              value
object_id
0         -0.070729
1          0.334200
2         -0.873894
3         -0.103574

Notice that this gradient is the exact same type of object as the decision.
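Since the gradient is \(x - x^\star\), its value at the zero decision is \(-x^\star\), so the numbers printed above give us the solution of this instance. The numpy sketch below uses them to check that \(Ax^\star = b\). It rebuilds a dense \(A\) and \(b\) from the arrow and source tables printed earlier. Reading an arrow (from, to, value) as the entry A[from, to] is our own assumption and is not stated by energnn, but it is consistent with these numbers.

```python
import numpy as np

# Arrow hyper-edges copied from the context table: (from, to, value).
arrows = [
    (0, 1, 1.457083),
    (0, 3, 3.200915),
    (1, 2, 2.025772),
    (2, 1, 1.216732),
    (3, 2, 0.748781),
]
# Source hyper-edges: one value per address id (ids 0..3, in order).
b = np.array([-0.155426, 1.770309, -0.406632, 0.654355])

# Assumed convention: arrow (from, to, value) contributes A[from, to] += value.
n = len(b)
A = np.zeros((n, n))
for i, j, value in arrows:
    A[i, j] += value

# The gradient at the zero decision is -x_star, so negate the printed values.
gradient_at_zero = np.array([-0.070729, 0.334200, -0.873894, -0.103574])
x_star = -gradient_at_zero

# Small residual, limited by the 6-digit rounding of the printed tables.
residual = A @ x_star - b
print(np.max(np.abs(residual)))
```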

As a quick sanity check, we can run a few steps of gradient descent and make sure that the objective decreases.

[28]:
from energnn.graph import JaxGraph

alpha = 0.5

objective, _ = problem.get_score(decision=decision)
print(f"Step 0, objective = {objective}")

for i in range(10):
    gradient, _ = problem.get_gradient(decision=decision)

    # Update decision
    numpy_gradient = gradient.to_numpy_graph()  # For now, we need to convert the gradient to a numpy graph.
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)

    objective, _ = problem.get_score(decision=decision)
    print(f"Step {i + 1}, objective = {objective}")
Step 0, objective = 0.2227775752544403
Step 1, objective = 0.05569439381361008
Step 2, objective = 0.013923597522079945
Step 3, objective = 0.003480899380519986
Step 4, objective = 0.0008702246705070138
Step 5, objective = 0.00021755615307483822
Step 6, objective = 5.438903463073075e-05
Step 7, objective = 1.3597240467788652e-05
Step 8, objective = 3.399258957870188e-06
Step 9, objective = 8.49789046242222e-07
Step 10, objective = 2.1243486969524383e-07
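The objective shrinks by the same factor, 4, at every step, and this is expected. With gradient \(x - x^\star\), one step gives \(x_{k+1} - x^\star = (1 - \alpha)(x_k - x^\star)\). Any objective proportional to \(\Vert x - x^\star \Vert^2\) is therefore multiplied by \((1 - \alpha)^2 = 0.25\) for \(\alpha = 0.5\). The plain numpy sketch below reproduces this and is independent of energnn. It uses the solution recovered from the gradient at the zero decision, \(x^\star = -\nabla(0)\).

```python
import numpy as np

# Solution recovered from the printed gradient at the zero decision.
x_star = np.array([0.070729, -0.334200, 0.873894, 0.103574])

def objective(x):
    # Squared error 1/2 * ||x - x_star||^2.
    return 0.5 * np.sum((x - x_star) ** 2)

def gradient(x):
    # Exact gradient of the objective above.
    return x - x_star

alpha = 0.5
x = np.zeros_like(x_star)
objectives = [objective(x)]
for _ in range(10):
    x = x - alpha * gradient(x)
    objectives.append(objective(x))

# Ratio between consecutive objectives: (1 - alpha) ** 2 = 0.25 at every step.
ratios = [after / before for before, after in zip(objectives, objectives[1:])]
print(ratios)
```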

The objective function successfully decreases! Now, let’s explore how multiple problems can be batched together.

Problem Batch

Interacting with a single problem instance is useful at inference time, or for debugging purposes. To train a whole Graph Neural Network model, however, we need to process batches of problem instances together.

[29]:
from energnn.problem.example import LinearSystemProblemGenerator

pb_generator = LinearSystemProblemGenerator(seed=9, n_max=3)
problem_batch = pb_generator.generate_problem_batch(batch_size=3)

# Let us explore the context and decision structures
print("Context Structure:\n", problem_batch.context_structure, "\n")
print("Decision Structure:\n", problem_batch.decision_structure)
Context Structure:
              Ports Features
Name
arrow   [from, to]  [value]
source        [id]  [value]

Decision Structure:
        Ports Features
Name
source  None  [value]

Here, contexts are still graphs, but this time with an extra batch dimension, indexed by batch_id:

[30]:
context, _ = problem_batch.get_context()
print(context)
arrow
                   ports       features
                    from   to     value
batch_id object_id
0        0           0.0  0.0 -1.116066
         1           1.0  0.0 -0.481135
         2           1.0  1.0 -1.517331
         3           1.0  2.0 -0.490872
         4           2.0  1.0 -0.647947
         5           2.0  2.0  0.635891
         6           0.0  0.0  0.000000
         7           0.0  0.0  0.000000
         8           0.0  0.0  0.000000
1        0           0.0  0.0 -0.857040
         1           0.0  1.0  1.528224
         2           0.0  2.0  0.904988
         3           1.0  0.0  0.541645
         4           1.0  1.0  0.701052
         5           1.0  2.0 -0.054635
         6           2.0  0.0  0.081804
         7           2.0  1.0 -1.281731
         8           2.0  2.0  0.158457
2        0           0.0  0.0 -0.745505
         1           0.0  0.0  0.000000
         2           0.0  0.0  0.000000
         3           0.0  0.0  0.000000
         4           0.0  0.0  0.000000
         5           0.0  0.0  0.000000
         6           0.0  0.0  0.000000
         7           0.0  0.0  0.000000
         8           0.0  0.0  0.000000
source
                   ports  features
                      id     value
batch_id object_id
0        0           0.0 -1.942086
         1           1.0 -1.634691
         2           2.0  0.257661
1        0           0.0 -1.926273
         1           1.0 -0.110379
         2           2.0  0.666537
2        0           0.0 -0.104837
         1           0.0  0.000000
         2           0.0  0.000000

Notice that the different contexts of the batch do not have the same connectivity, and do not have the same number of arrow and source objects. To batch the different contexts together, they are thus padded with zeros up to a common size.
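Padding can be sketched in a few lines of numpy. The helper pad_feature below is hypothetical and not part of energnn. It stacks variable-length feature arrays into one (batch_size, n_max) array filled with zeros. It also returns a boolean mask, which is our own addition: it tells real objects apart from padding, since a padded 0.0 is otherwise indistinguishable from a genuine zero.

```python
import numpy as np

def pad_feature(arrays, n_max):
    """Hypothetical helper: zero-pad 1-D arrays to length n_max and stack them."""
    batch = np.zeros((len(arrays), n_max))
    mask = np.zeros((len(arrays), n_max), dtype=bool)
    for k, values in enumerate(arrays):
        batch[k, : len(values)] = values  # real objects first, zeros afterwards
        mask[k, : len(values)] = True     # True marks a real (non-padded) object
    return batch, mask

# Source values of the three contexts printed above (3, 3 and 1 objects).
source_values = [
    np.array([-1.942086, -1.634691, 0.257661]),
    np.array([-1.926273, -0.110379, 0.666537]),
    np.array([-0.104837]),
]
padded, mask = pad_feature(source_values, n_max=3)
print(padded)
print(mask.sum(axis=1))  # number of real objects per context: [3 3 1]
```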

Still, a ProblemBatch can be handled in a very similar way to a single Problem.

[31]:
alpha = 0.5

decision, _ = problem_batch.get_zero_decision()
objective, _ = problem_batch.get_score(decision=decision)
print(f"Step 0, objective = {objective}")

for i in range(10):
    gradient, _ = problem_batch.get_gradient(decision=decision)

    # Update decision
    numpy_gradient = gradient.to_numpy_graph()
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)

    objective, _ = problem_batch.get_score(decision=decision)
    print(f"Step {i + 1}, objective = {objective}")
Step 0, objective = [1.205530047416687, 0.3518087863922119, 0.006591798271983862]
Step 1, objective = [0.30138251185417175, 0.08795219659805298, 0.0016479495679959655]
Step 2, objective = [0.07534561306238174, 0.021988049149513245, 0.0004119873046875]
Step 3, objective = [0.018836403265595436, 0.005497012287378311, 0.000102996826171875]
Step 4, objective = [0.0047090970911085606, 0.0013742543524131179, 2.574920654296875e-05]
Step 5, objective = [0.0011772741563618183, 0.00034356300602667034, 6.4373016357421875e-06]
Step 6, objective = [0.00029431749135255814, 8.589089702581987e-05, 1.6093254089355469e-06]
Step 7, objective = [7.357935101026669e-05, 2.1472886146511883e-05, 4.023313522338867e-07]
Step 8, objective = [1.8394850485492498e-05, 5.3682501857110765e-06, 1.0058283805847168e-07]
Step 9, objective = [4.598733994498616e-06, 1.342021732853027e-06, 2.514570951461792e-08]
Step 10, objective = [1.149680656453711e-06, 3.3547598832228687e-07, 6.28642737865448e-09]

Notice that scores are now lists of floats, one per problem instance in the batch.

Problem Loader

Processing problem instances in batches is useful, but not enough. To train a Graph Neural Network, we need to iterate over multiple minibatches of problem instances. That’s where the ProblemLoader class comes in.

[32]:
from energnn.problem.example import LinearSystemProblemLoader

problem_loader = LinearSystemProblemLoader(batch_size=4, seed=7, dataset_size=16, n_max=4)

# Let us explore the context and decision structures
print("Context Structure:\n", problem_loader.context_structure, "\n")
print("Decision Structure:\n", problem_loader.decision_structure)
Context Structure:
              Ports Features
Name
arrow   [from, to]  [value]
source        [id]  [value]

Decision Structure:
        Ports Features
Name
source  None  [value]

It lets us iterate over batches of problems.
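Conceptually, a loader only needs to split a dataset into consecutive minibatches. The class below is a hypothetical plain-Python stand-in, not energnn's ProblemLoader, and it iterates over problem indices rather than actual problems. With dataset_size=16 and batch_size=4 it yields four batches, which matches the four lines printed below.

```python
class ToyLoader:
    """Hypothetical stand-in for a problem loader: yields minibatches of indices."""

    def __init__(self, dataset_size, batch_size):
        self.dataset_size = dataset_size
        self.batch_size = batch_size

    def __iter__(self):
        # Consecutive slices of the dataset; the last batch may be smaller.
        for start in range(0, self.dataset_size, self.batch_size):
            stop = min(start + self.batch_size, self.dataset_size)
            yield list(range(start, stop))

    def __len__(self):
        # Number of minibatches per pass over the dataset (ceiling division).
        return -(-self.dataset_size // self.batch_size)


loader = ToyLoader(dataset_size=16, batch_size=4)
for batch in loader:
    print(batch)
```

Each call to iter() starts a fresh generator, so the loader can be traversed once per epoch.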

[33]:
for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()
    decision, _ = problem_batch.get_zero_decision()
    objective, _ = problem_batch.get_score(decision=decision)
    print("Objective:", objective)
Objective: [0.2227775752544403, 0.5266321897506714, 0.04437382146716118, 0.17635096609592438]
Objective: [0.3786073327064514, 0.6252117156982422, 0.338590145111084, 0.010212971828877926]
Objective: [1.4664279222488403, 0.7719758749008179, 1.2088079452514648, 0.7055290937423706]
Objective: [1.4254088401794434, 2.1467173099517822, 0.37498927116394043, 0.3698219358921051]

Graph Neural Network Model

Let us instantiate a small Graph Neural Network model adapted to the context and decision structures of our problem class.

[34]:
from energnn.model.ready_to_use import TinyRecurrentEquivariantGNN

model = TinyRecurrentEquivariantGNN(
    in_structure=problem_loader.context_structure,
    out_structure=problem_loader.decision_structure
)

Make sure that your model is in evaluation mode first!

[35]:
model.eval()  # Set the model in evaluation mode.
# model.train()  # To set the model in train mode.

It takes a context as input and returns a decision.

[36]:
problem = pb_generator.generate_problem()
context, _ = problem.get_context()
decision, _ = model(context)
print(decision)
source
           features
              value
object_id
0          0.237434

It can also process batches of contexts and return batches of decisions.

[37]:
problem_batch = pb_generator.generate_problem_batch(batch_size=4)
context, _ = problem_batch.get_context()
decision, _ = model.forward_batch(graph=context)
print(decision)
source
                    features
                       value
batch_id object_id
0        0          0.294713
         1          0.000000
         2          0.000000
1        0          0.000000
         1          3.260542
         2          0.000000
2        0          0.326413
         1          0.000000
         2          0.000000
3        0         -0.784929
         1         -0.000000
         2         -0.000000

Trainer

Let us train our Graph Neural Network model over a problem loader. The core training loop is defined by the following pseudocode.

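for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()                      # batch of contexts
    decision, _ = model.forward_batch(graph=context)              # decisions predicted by the GNN
    gradient, _ = problem_batch.get_gradient(decision=decision)   # objective gradient w.r.t. the decisions
    model.backprop(gradient)                                      # conceptual step: chain rule to the parameters, then update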
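The backprop step is the chain rule. The problem supplies the gradient \(g\) of the objective with respect to the decision \(y\). The model turns it into a gradient with respect to its parameters \(\theta\), namely \((\partial y / \partial \theta)^\top g\), a vector-Jacobian product. The numpy sketch below makes this concrete for a hypothetical linear model \(y = W c\), for which the product reduces to the outer product \(g\,c^\top\), and checks it against finite differences. It is not energnn's implementation; a JAX-based model would typically obtain the same product with jax.vjp.

```python
import numpy as np

rng = np.random.default_rng(0)
c = rng.normal(size=3)        # toy context features
W = rng.normal(size=(2, 3))   # toy model parameters
y_star = rng.normal(size=2)   # toy target decision

def objective(W):
    # 1/2 * ||y - y_star||^2 with the linear model y = W c.
    y = W @ c
    return 0.5 * np.sum((y - y_star) ** 2)

# Step 1: the "problem" provides the gradient with respect to the decision y.
y = W @ c
g = y - y_star

# Step 2: "backprop" pulls g back to the parameters: dL/dW = g c^T.
grad_W = np.outer(g, c)

# Central finite-difference check of one entry of grad_W.
eps = 1e-6
W_plus, W_minus = W.copy(), W.copy()
W_plus[1, 2] += eps
W_minus[1, 2] -= eps
fd = (objective(W_plus) - objective(W_minus)) / (2 * eps)
print(grad_W[1, 2], fd)  # the two values agree to several decimal places

# Step 3: a plain gradient step; an optimizer such as optax's Adam would replace this.
W = W - 0.1 * grad_W
```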

In practice, we use energnn.trainer to implement the training logic, which lets us use:

  • optax for the optimizer,

  • orbax for checkpointing and saving/loading models.

[38]:
from energnn.trainer import Trainer
import optax

trainer = Trainer(model=model, gradient_transformation=optax.adam(learning_rate=1e-3))

The training is performed by iterating over a train loader, and the validation score is periodically computed on a validation loader.

[39]:
train_loader = LinearSystemProblemLoader(seed=7, dataset_size=64, batch_size=4, n_max=3)
val_loader = LinearSystemProblemLoader(seed=8, dataset_size=8, batch_size=4, n_max=3)
[40]:
_ = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    eval_before_training=True,
    n_epochs=3,
)
Validation: 100%|██████████| 2/2 [00:03<00:00,  1.83s/batch, score=2.1002e+00]
Epoch 1/3: 100%|██████████| 16/16 [00:22<00:00,  1.42s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.26batch/s, score=1.0779e+00]
Epoch 2/3: 100%|██████████| 16/16 [00:19<00:00,  1.21s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.11batch/s, score=9.9698e-01]
Epoch 3/3: 100%|██████████| 16/16 [00:18<00:00,  1.13s/batch]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.21batch/s, score=9.3471e-01]
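The log above shows one validation pass before training (eval_before_training=True), then one after each of the 3 epochs. Each epoch covers 16 training batches (64 problems / 4), and each validation covers 2 batches (8 problems / 4). The skeleton below sketches that schedule. It is hypothetical and simplified, not Trainer.train itself, and train_step and validate are placeholder callables introduced for illustration.

```python
def train(train_batches, val_batches, train_step, validate, n_epochs, eval_before_training=True):
    """Hypothetical training skeleton with periodic validation."""
    history = []
    if eval_before_training:
        history.append(validate(val_batches))  # baseline score of the untrained model
    for _ in range(n_epochs):
        for batch in train_batches:
            train_step(batch)                  # forward, problem gradient, backprop, update
        history.append(validate(val_batches))  # validation score after each epoch
    return history


# Toy usage: "training" counts steps, "validation" reports the step count.
state = {"steps": 0}

def train_step(batch):
    state["steps"] += 1

def validate(batches):
    return state["steps"]

scores = train(range(16), range(2), train_step, validate, n_epochs=3)
print(scores)  # [0, 16, 32, 48]: 1 + n_epochs validation passes
```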