{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "# Tutorial\n", "\n", "In this `energnn` tutorial, we will review:\n", "- The installation of `energnn`,\n", "- The interaction with typical implementations of `energnn.problem.Problem`, `energnn.problem.ProblemBatch` and `energnn.problem.ProblemLoader`,\n", "- The creation of a GNN model using `energnn.model.ready_to_use`,\n", "- The training of the model with `energnn.trainer.Trainer`,\n", "- The usage of the trained model." ], "id": "ef720d484a9302d7" }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Installation\n", "\n", "To install the latest stable release of `energnn` for CPU:\n", "```bash\n", "pip install energnn\n", "```\n", "For the GPU version, install the `gpu` extra:\n", "```bash\n", "pip install \"energnn[gpu]\"\n", "```" ], "id": "7e9fb261104f6c9e" }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Problem Class\n", "\n", "Let's consider the following use case.\n", "Given the pair $(A, b)$, we wish to find $x$ such that $Ax = b$.\n", "Let us generate a random problem instance and explore its interface." ], "id": "23617362f5203a69" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.190219843Z", "start_time": "2026-03-07T21:25:12.743521764Z" } }, "cell_type": "code", "source": [ "from energnn.problem.example import LinearSystemProblemGenerator\n", "\n", "pb_generator = LinearSystemProblemGenerator(seed=7, n_max=4)\n", "problem = pb_generator.generate_problem()" ], "id": "782b06152596c14", "outputs": [], "execution_count": 21 }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Context Graph\n", "\n", "The input of our GNN model is referred to as **the context**, instantiated as an `energnn.graph.Graph` object.\n", "\n", "In this case, it is the pair $(A, b)$, framed as a Hyper Heterogeneous Multi Graph (H2MG).\n", "The matrix $A$ is represented by a hyper-edge set called **arrow**, and the vector $b$ is represented by a hyper-edge set called **source**."
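, "\n",
"\n",
"To make this encoding concrete, here is a minimal sketch in plain NumPy (not `energnn` code) of how a dense pair $(A, b)$ could be flattened into these two hyper-edge sets. It assumes one **arrow** per nonzero entry $A_{ij}$, with port `from` holding the row index $i$ and port `to` holding the column index $j$, and one **source** per entry $b_i$ with port `id` holding $i$. The helper `to_h2mg_tables` and its dictionary layout are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def to_h2mg_tables(A, b):\n",
"    # Hypothetical helper: one arrow per nonzero entry of A\n",
"    # (assumed convention: row index -> 'from' port, column index -> 'to' port).\n",
"    rows, cols = np.nonzero(A)\n",
"    arrow = {\n",
"        'ports': {'from': rows, 'to': cols},\n",
"        'features': {'value': A[rows, cols]},\n",
"    }\n",
"    # One source per entry of b; its single port 'id' is the address it attaches to.\n",
"    source = {\n",
"        'ports': {'id': np.arange(len(b))},\n",
"        'features': {'value': np.asarray(b, dtype=float)},\n",
"    }\n",
"    return arrow, source\n",
"\n",
"\n",
"A = np.array([[0.0, 1.5, 0.0],\n",
"              [0.0, 0.0, 2.0],\n",
"              [0.7, 0.0, 0.0]])\n",
"b = np.array([1.0, -0.5, 0.3])\n",
"arrow, source = to_h2mg_tables(A, b)\n",
"print(arrow['ports']['from'], arrow['ports']['to'], arrow['features']['value'])\n",
"```\n",
"\n",
"With this convention, the 3 nonzero entries of $A$ give 3 **arrow** objects, and $b$ gives 3 **source** objects."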
], "id": "3884df19af92c1e3" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.413183347Z", "start_time": "2026-03-07T21:25:13.208153134Z" } }, "cell_type": "code", "source": [ "# Let us explore the context structure\n", "print(problem.context_structure)" ], "id": "8b5eb762b52a202f", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Ports Features\n", "Name \n", "arrow [from, to] [value]\n", "source [id] [value]\n" ] } ], "execution_count": 22 }, { "metadata": {}, "cell_type": "markdown", "source": "Each hyper-edge set has ports and/or features. Ports define the connectivity of the graph: each port associates a hyper-edge with an integer **address**. For instance, in the context printed below, the first **arrow** (port `from`) and the first **source** (port `id`) both point to address 0.", "id": "af7708e8d2180038" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.579066449Z", "start_time": "2026-03-07T21:25:13.415891907Z" } }, "cell_type": "code", "source": [ "# Print the context graph associated with the problem instance\n", "context, _ = problem.get_context()\n", "print(context)" ], "id": "2c4eb40211b3d616", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "arrow\n", " ports features\n", " from to value\n", "object_id \n", "0 0.0 1.0 1.457083\n", "1 0.0 3.0 3.200915\n", "2 1.0 2.0 2.025772\n", "3 2.0 1.0 1.216732\n", "4 3.0 2.0 0.748781\n", "source\n", " ports features\n", " id value\n", "object_id \n", "0 0.0 -0.155426\n", "1 1.0 1.770309\n", "2 2.0 -0.406632\n", "3 3.0 0.654355\n", "\n" ] } ], "execution_count": 23 }, { "metadata": {}, "cell_type": "markdown", "source": [ "The **context** of this specific problem instance has:\n", "\n", "- 5 **arrow** objects,\n", "- 4 **source** objects."
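, "\n",
"\n",
"Conversely, the printed tables are enough to rebuild the dense pair $(A, b)$. The plain NumPy sketch below copies the values shown above by hand (it does not call `energnn`) and assumes that port `from` gives the row and port `to` gives the column of $A$, with all addresses referring to the same 4 unknowns.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# (from, to, value) rows copied from the printed 'arrow' table\n",
"arrow_rows = [\n",
"    (0, 1, 1.457083),\n",
"    (0, 3, 3.200915),\n",
"    (1, 2, 2.025772),\n",
"    (2, 1, 1.216732),\n",
"    (3, 2, 0.748781),\n",
"]\n",
"# (id, value) rows copied from the printed 'source' table\n",
"source_rows = [(0, -0.155426), (1, 1.770309), (2, -0.406632), (3, 0.654355)]\n",
"\n",
"n = len(source_rows)  # one address per unknown\n",
"A = np.zeros((n, n))\n",
"for i, j, value in arrow_rows:\n",
"    A[i, j] = value  # assumed convention: 'from' = row, 'to' = column\n",
"b = np.zeros(n)\n",
"for i, value in source_rows:\n",
"    b[i] = value\n",
"\n",
"print(np.count_nonzero(A), A.shape, b.shape)\n",
"```\n",
"\n",
"Under this reading, the 5 **arrow** objects are exactly the 5 nonzero entries of a $4 \\times 4$ matrix $A$, and the 4 **source** objects are the entries of $b$."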
], "id": "d56da62a16b8933c" }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Decision Graph\n", "The output of our GNN model is referred to as **the decision**, instantiated as an `energnn.graph.Graph` object.\n", "\n", "In this case, it is the variable $x$.\n", "\n", "This specific problem class has a helper method called `get_zero_decision` (not part of the mandatory interface)\n", "that returns a decision of the right shape and structure, filled with zeros." ], "id": "1d83e187562eb518" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.632723989Z", "start_time": "2026-03-07T21:25:13.583060222Z" } }, "cell_type": "code", "source": [ "# Let us explore the decision structure\n", "print(problem.decision_structure)" ], "id": "1bf664a331fa8474", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Ports Features\n", "Name \n", "source None [value]\n" ] } ], "execution_count": 24 }, { "metadata": {}, "cell_type": "markdown", "source": "Notice that the decision only involves a subset of the hyper-edge classes available in the context (here, **source**), and that it has no ports, only features.", "id": "14c42bb1208cc05a" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.735309862Z", "start_time": "2026-03-07T21:25:13.656228528Z" } }, "cell_type": "code", "source": [ "# Print the zero decision associated with the problem instance\n", "decision, _ = problem.get_zero_decision()\n", "print(decision)" ], "id": "dff64416f5fa11be", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source\n", " features\n", " value\n", "object_id \n", "0 0.0\n", "1 -0.0\n", "2 0.0\n", "3 0.0\n", "\n" ] } ], "execution_count": 25 }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Objective Function\n", "The score of a given decision is instantiated as a `float`.\n", "\n", "In this case, we use the Mean Squared Error $\\frac{1}{2} \\Vert x - x^\\star \\Vert^2$, where $x^\\star$ is the solution.\n", "\n", "We can evaluate it by 
injecting the zero decision we have just retrieved." ], "id": "39689e04e2473bb0" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.793647796Z", "start_time": "2026-03-07T21:25:13.741580204Z" } }, "cell_type": "code", "source": [ "score, _ = problem.get_score(decision=decision)\n", "print(score)" ], "id": "85ca3edb9340e2bc", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.2227775752544403\n" ] } ], "execution_count": 26 }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Gradient Graph\n", "The gradient of the objective function is instantiated as an `energnn.graph.Graph` object.\n", "\n", "In this case, it is the vector $x-x^\\star$, so at the zero decision it equals $-x^\\star$.\n", "Notice that in more complex use cases, the gradient can have a more involved expression, or even require Monte Carlo simulations.\n", "\n", "We can evaluate it by injecting the zero decision." ], "id": "5097b7c1a288965d" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:13.887299079Z", "start_time": "2026-03-07T21:25:13.823394349Z" } }, "cell_type": "code", "source": [ "gradient, _ = problem.get_gradient(decision=decision)\n", "print(gradient)" ], "id": "93b4a8ac0de93c61", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source\n", " features\n", " value\n", "object_id \n", "0 -0.070729\n", "1 0.334200\n", "2 -0.873894\n", "3 -0.103574\n", "\n" ] } ], "execution_count": 27 }, { "metadata": {}, "cell_type": "markdown", "source": [ "Notice that this gradient has exactly the same type and structure as the decision.\n", "\n", "As a quick sanity check, we can run a few steps of gradient descent and make sure that the objective decreases."
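, "\n",
"\n",
"Before running it, note what to expect. Since the gradient is $x - x^\\star$, the update $x \\leftarrow x - \\alpha (x - x^\\star)$ multiplies the error $x - x^\\star$ by $(1 - \\alpha)$, so a squared-error objective is multiplied by $(1 - \\alpha)^2 = 0.25$ at each step when $\\alpha = 0.5$. The plain NumPy sketch below (independent of `energnn`) checks this, taking $x^\\star$ as minus the gradient printed above at the zero decision. It only compares ratios between successive objectives, so a constant scaling of the objective does not matter.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# x* = -(gradient at the zero decision), copied from the printed gradient above\n",
"x_star = -np.array([-0.070729, 0.334200, -0.873894, -0.103574])\n",
"alpha = 0.5\n",
"\n",
"\n",
"def objective(x):\n",
"    # Squared error 1/2 ||x - x*||^2; any positive rescaling leaves the ratios unchanged\n",
"    return 0.5 * np.sum((x - x_star) ** 2)\n",
"\n",
"\n",
"x = np.zeros_like(x_star)  # zero decision\n",
"values = [objective(x)]\n",
"for _ in range(10):\n",
"    gradient = x - x_star  # gradient of the objective above\n",
"    x = x - alpha * gradient\n",
"    values.append(objective(x))\n",
"\n",
"ratios = np.array(values[1:]) / np.array(values[:-1])\n",
"print(ratios)\n",
"```\n",
"\n",
"Every ratio is $0.25$ up to rounding, which is the decay visible in the run below (for instance, $0.2228 \\to 0.0557$ after the first step)."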
], "id": "c18a4c077db1e8ef" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.096196263Z", "start_time": "2026-03-07T21:25:13.891205146Z" } }, "cell_type": "code", "source": [ "from energnn.graph import JaxGraph\n", "\n", "alpha = 0.5  # Step size\n", "\n", "objective, _ = problem.get_score(decision=decision)\n", "print(f\"Step 0, objective = {objective}\")\n", "\n", "for i in range(10):\n", "    gradient, _ = problem.get_gradient(decision=decision)\n", "\n", "    # Update decision\n", "    numpy_gradient = gradient.to_numpy_graph()  # For now, the update is done on NumPy graphs, so both graphs are converted.\n", "    numpy_decision = decision.to_numpy_graph()\n", "    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array\n", "    decision = JaxGraph.from_numpy_graph(numpy_decision)\n", "\n", "    objective, _ = problem.get_score(decision=decision)\n", "    print(f\"Step {i + 1}, objective = {objective}\")" ], "id": "e327fb01843557b8", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0, objective = 0.2227775752544403\n", "Step 1, objective = 0.05569439381361008\n", "Step 2, objective = 0.013923597522079945\n", "Step 3, objective = 0.003480899380519986\n", "Step 4, objective = 0.0008702246705070138\n", "Step 5, objective = 0.00021755615307483822\n", "Step 6, objective = 5.438903463073075e-05\n", "Step 7, objective = 1.3597240467788652e-05\n", "Step 8, objective = 3.399258957870188e-06\n", "Step 9, objective = 8.49789046242222e-07\n", "Step 10, objective = 2.1243486969524383e-07\n" ] } ], "execution_count": 28 }, { "metadata": {}, "cell_type": "markdown", "source": "The objective function successfully decreases! 
Now, let's explore how multiple problems can be batched together.", "id": "27142d6473d5fe87" }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Problem Batch\n", "\n", "Interacting with a single problem instance is useful at inference time or for debugging.\n", "But to train a Graph Neural Network model, we need to process batches of problem instances together." ], "id": "9a7d1de139e03db9" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.258402486Z", "start_time": "2026-03-07T21:25:14.111448731Z" } }, "cell_type": "code", "source": [ "from energnn.problem.example import LinearSystemProblemGenerator\n", "\n", "pb_generator = LinearSystemProblemGenerator(seed=9, n_max=3)\n", "problem_batch = pb_generator.generate_problem_batch(batch_size=3)\n", "\n", "# Let us explore the context and decision structures\n", "print(\"Context Structure:\\n\", problem_batch.context_structure, \"\\n\")\n", "print(\"Decision Structure:\\n\", problem_batch.decision_structure)" ], "id": "43f1ba98b8785caf", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Context Structure:\n", " Ports Features\n", "Name \n", "arrow [from, to] [value]\n", "source [id] [value] \n", "\n", "Decision Structure:\n", " Ports Features\n", "Name \n", "source None [value]\n" ] } ], "execution_count": 29 }, { "metadata": {}, "cell_type": "markdown", "source": "Here, the context is still a graph, but this time with an extra batch dimension (`batch_id`):", "id": "131548bb060e4d50" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.350313058Z", "start_time": "2026-03-07T21:25:14.280395972Z" } }, "cell_type": "code", "source": [ "context, _ = problem_batch.get_context()\n", "print(context)" ], "id": "80e2d49e249d029b", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "arrow\n", " ports features\n", " from to value\n", "batch_id object_id \n", "0 0 0.0 0.0 -1.116066\n", " 1 1.0 0.0 -0.481135\n", " 2 1.0 1.0 -1.517331\n", " 3 1.0 2.0 
-0.490872\n", " 4 2.0 1.0 -0.647947\n", " 5 2.0 2.0 0.635891\n", " 6 0.0 0.0 0.000000\n", " 7 0.0 0.0 0.000000\n", " 8 0.0 0.0 0.000000\n", "1 0 0.0 0.0 -0.857040\n", " 1 0.0 1.0 1.528224\n", " 2 0.0 2.0 0.904988\n", " 3 1.0 0.0 0.541645\n", " 4 1.0 1.0 0.701052\n", " 5 1.0 2.0 -0.054635\n", " 6 2.0 0.0 0.081804\n", " 7 2.0 1.0 -1.281731\n", " 8 2.0 2.0 0.158457\n", "2 0 0.0 0.0 -0.745505\n", " 1 0.0 0.0 0.000000\n", " 2 0.0 0.0 0.000000\n", " 3 0.0 0.0 0.000000\n", " 4 0.0 0.0 0.000000\n", " 5 0.0 0.0 0.000000\n", " 6 0.0 0.0 0.000000\n", " 7 0.0 0.0 0.000000\n", " 8 0.0 0.0 0.000000\n", "source\n", " ports features\n", " id value\n", "batch_id object_id \n", "0 0 0.0 -1.942086\n", " 1 1.0 -1.634691\n", " 2 2.0 0.257661\n", "1 0 0.0 -1.926273\n", " 1 1.0 -0.110379\n", " 2 2.0 0.666537\n", "2 0 0.0 -0.104837\n", " 1 0.0 0.000000\n", " 2 0.0 0.000000\n", "\n" ] } ], "execution_count": 30 }, { "metadata": {}, "cell_type": "markdown", "source": [ "Notice that the different contexts of the batch do not have the same connectivity, and do not have the same number of **arrow** and **source** objects.\n", "To batch the different contexts together, it is thus necessary to pad them with zeros.\n", "\n", "Still, a `ProblemBatch` can be handled in a very similar way to a single `Problem`." 
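, "\n",
"\n",
"To picture the zero padding, here is a plain NumPy sketch; it is not how `energnn` implements batching. It stacks per-instance feature arrays into one array whose second axis is the largest object count in the batch, fills the missing rows with zeros, and also builds a boolean mask, a common way to tell real objects from padding (whether and how `energnn` tracks padding is not shown in this tutorial). The helper `pad_and_stack` is hypothetical, and the values are the **source** features printed above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pad_and_stack(arrays):\n",
"    # Hypothetical helper: pad each (n_i, d) array with zero rows up to max(n_i), then stack.\n",
"    n_max = max(a.shape[0] for a in arrays)\n",
"    d = arrays[0].shape[1]\n",
"    padded = np.zeros((len(arrays), n_max, d))\n",
"    mask = np.zeros((len(arrays), n_max), dtype=bool)\n",
"    for k, a in enumerate(arrays):\n",
"        padded[k, : a.shape[0]] = a\n",
"        mask[k, : a.shape[0]] = True  # True for real objects, False for padding\n",
"    return padded, mask\n",
"\n",
"\n",
"# Source values of the three instances above, which have 3, 3 and 1 real objects\n",
"sources = [\n",
"    np.array([[-1.942086], [-1.634691], [0.257661]]),\n",
"    np.array([[-1.926273], [-0.110379], [0.666537]]),\n",
"    np.array([[-0.104837]]),\n",
"]\n",
"padded, mask = pad_and_stack(sources)\n",
"print(padded.shape, mask.sum(axis=1))\n",
"```\n",
"\n",
"The third instance ends up with two zero rows, as in the printed **source** table."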
], "id": "90629ee6907ca540" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.414980538Z", "start_time": "2026-03-07T21:25:14.352290029Z" } }, "cell_type": "code", "source": [ "alpha = 0.5  # Step size\n", "\n", "decision, _ = problem_batch.get_zero_decision()\n", "objective, _ = problem_batch.get_score(decision=decision)\n", "print(f\"Step 0, objective = {objective}\")\n", "\n", "for i in range(10):\n", "    gradient, _ = problem_batch.get_gradient(decision=decision)\n", "\n", "    # Update decision (same NumPy round trip as for a single problem)\n", "    numpy_gradient = gradient.to_numpy_graph()\n", "    numpy_decision = decision.to_numpy_graph()\n", "    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array\n", "    decision = JaxGraph.from_numpy_graph(numpy_decision)\n", "\n", "    objective, _ = problem_batch.get_score(decision=decision)\n", "    print(f\"Step {i + 1}, objective = {objective}\")" ], "id": "3eeb9881e583cfc2", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0, objective = [1.205530047416687, 0.3518087863922119, 0.006591798271983862]\n", "Step 1, objective = [0.30138251185417175, 0.08795219659805298, 0.0016479495679959655]\n", "Step 2, objective = [0.07534561306238174, 0.021988049149513245, 0.0004119873046875]\n", "Step 3, objective = [0.018836403265595436, 0.005497012287378311, 0.000102996826171875]\n", "Step 4, objective = [0.0047090970911085606, 0.0013742543524131179, 2.574920654296875e-05]\n", "Step 5, objective = [0.0011772741563618183, 0.00034356300602667034, 6.4373016357421875e-06]\n", "Step 6, objective = [0.00029431749135255814, 8.589089702581987e-05, 1.6093254089355469e-06]\n", "Step 7, objective = [7.357935101026669e-05, 2.1472886146511883e-05, 4.023313522338867e-07]\n", "Step 8, objective = [1.8394850485492498e-05, 5.3682501857110765e-06, 1.0058283805847168e-07]\n", "Step 9, objective = [4.598733994498616e-06, 1.342021732853027e-06, 2.514570951461792e-08]\n", "Step 10, objective = [1.149680656453711e-06, 3.3547598832228687e-07, 6.28642737865448e-09]\n"
] } ], "execution_count": 31 }, { "metadata": {}, "cell_type": "markdown", "source": "Notice that scores are now lists of `float`, with one value per problem instance in the batch.", "id": "671f8a1c05f96e8d" }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Problem Loader\n", "\n", "Being able to process problem instances in batches is nice, but not enough.\n", "To train a Graph Neural Network, we'll need to iterate over multiple minibatches of problem instances.\n", "That's where the `ProblemLoader` class comes in." ], "id": "f0c9031841eb74e5" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.461205586Z", "start_time": "2026-03-07T21:25:14.429191698Z" } }, "cell_type": "code", "source": [ "from energnn.problem.example import LinearSystemProblemLoader\n", "\n", "problem_loader = LinearSystemProblemLoader(batch_size=4, seed=7, dataset_size=16, n_max=4)\n", "\n", "# Let us explore the context and decision structures\n", "print(\"Context Structure:\\n\", problem_loader.context_structure, \"\\n\")\n", "print(\"Decision Structure:\\n\", problem_loader.decision_structure)" ], "id": "dcdcda287f5a5242", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Context Structure:\n", " Ports Features\n", "Name \n", "arrow [from, to] [value]\n", "source [id] [value] \n", "\n", "Decision Structure:\n", " Ports Features\n", "Name \n", "source None [value]\n" ] } ], "execution_count": 32 }, { "metadata": {}, "cell_type": "markdown", "source": "Iterating over the loader yields successive batches of problems. 
Here, the 16 problems of the dataset are split into 4 batches of 4.", "id": "a11322d96f17498f" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.676088380Z", "start_time": "2026-03-07T21:25:14.466461122Z" } }, "cell_type": "code", "source": [ "for problem_batch in problem_loader:\n", "    context, _ = problem_batch.get_context()\n", "    decision, _ = problem_batch.get_zero_decision()\n", "    objective, _ = problem_batch.get_score(decision=decision)\n", "    print(\"Objective:\", objective)" ], "id": "bb765a3cb4498e22", "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "Objective: [0.2227775752544403, 0.5266321897506714, 0.04437382146716118, 0.17635096609592438]\n", "Objective: [0.3786073327064514, 0.6252117156982422, 0.338590145111084, 0.010212971828877926]\n", "Objective: [1.4664279222488403, 0.7719758749008179, 1.2088079452514648, 0.7055290937423706]\n", "Objective: [1.4254088401794434, 2.1467173099517822, 0.37498927116394043, 0.3698219358921051]\n" ] } ], "execution_count": 33 }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Graph Neural Network Model\n", "\n", "Let us instantiate a small Graph Neural Network model adapted to the context and decision structures of our problem class." ], "id": "3bd3a54a84305889" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.945192958Z", "start_time": "2026-03-07T21:25:14.692186530Z" } }, "cell_type": "code", "source": [ "from energnn.model.ready_to_use import TinyRecurrentEquivariantGNN\n", "\n", "model = TinyRecurrentEquivariantGNN(\n", "    in_structure=problem_loader.context_structure,\n", "    out_structure=problem_loader.decision_structure,\n", ")" ], "id": "f44dfc44f02ee1c6", "outputs": [], "execution_count": 34 }, { "metadata": {}, "cell_type": "markdown", "source": "Before running inference, make sure that the model is in evaluation mode!", "id": "b8010c4dfb17417a" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:14.979446887Z", "start_time": "2026-03-07T21:25:14.959291164Z" } }, "cell_type": "code", "source": [ "model.eval()  # Put the model in evaluation mode.\n", "# model.train()  # Use this instead to put the model in training mode."
], "id": "6ed800f0d78e8cf", "outputs": [], "execution_count": 35 }, { "metadata": {}, "cell_type": "markdown", "source": "It takes a context as input and returns a decision. Since the model has not been trained yet, this decision is not expected to solve the problem.", "id": "c8fee306c0e7f49d" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:15.424302268Z", "start_time": "2026-03-07T21:25:14.981199200Z" } }, "cell_type": "code", "source": [ "problem = pb_generator.generate_problem()\n", "context, _ = problem.get_context()\n", "decision, _ = model(context)\n", "print(decision)" ], "id": "8da9ee7100536a1c", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source\n", " features\n", " value\n", "object_id \n", "0 0.237434\n", "\n" ] } ], "execution_count": 36 }, { "metadata": {}, "cell_type": "markdown", "source": "It can also process a batch of contexts with `forward_batch` and return a batch of decisions.", "id": "97d0fdd1cf01adf3" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:16.116544703Z", "start_time": "2026-03-07T21:25:15.427519405Z" } }, "cell_type": "code", "source": [ "problem_batch = pb_generator.generate_problem_batch(batch_size=4)\n", "context, _ = problem_batch.get_context()\n", "decision, _ = model.forward_batch(graph=context)\n", "print(decision)" ], "id": "e9d737e5f9867f55", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source\n", " features\n", " value\n", "batch_id object_id \n", "0 0 0.294713\n", " 1 0.000000\n", " 2 0.000000\n", "1 0 0.000000\n", " 1 3.260542\n", " 2 0.000000\n", "2 0 0.326413\n", " 1 0.000000\n", " 2 0.000000\n", "3 0 -0.784929\n", " 1 -0.000000\n", " 2 -0.000000\n", "\n" ] } ], "execution_count": 37 }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Trainer\n", "\n", "Let us train our Graph Neural Network model over a problem loader. 
The core training loop is defined by the following pseudocode.\n", "\n", "```python\n", "for problem_batch in problem_loader:\n", "    context, _ = problem_batch.get_context()\n", "    decision, _ = model.forward_batch(graph=context)\n", "    gradient, _ = problem_batch.get_gradient(decision=decision)\n", "    model.backprop(gradient)  # Schematic: backpropagate the decision gradient and update the parameters\n", "```\n", "\n", "In practice, the training logic is implemented in `energnn.trainer`, which lets us use:\n", "\n", "- `optax` for the optimizer,\n", "- `orbax` for checkpointing and saving/loading models." ], "id": "96198c25077f4af0" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:16.326885924Z", "start_time": "2026-03-07T21:25:16.210450574Z" } }, "cell_type": "code", "source": [ "from energnn.trainer import Trainer\n", "import optax\n", "\n", "trainer = Trainer(model=model, gradient_transformation=optax.adam(learning_rate=1e-3))" ], "id": "2770b4ca8179091b", "outputs": [], "execution_count": 38 }, { "metadata": {}, "cell_type": "markdown", "source": "Training iterates over a **train** loader, while the validation score is periodically computed on a **validation** loader (here, once before training and after each epoch, as shown in the logs below).", "id": "bd42d36691734ce6" }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:25:16.349184757Z", "start_time": "2026-03-07T21:25:16.328423260Z" } }, "cell_type": "code", "source": [ "train_loader = LinearSystemProblemLoader(seed=7, dataset_size=64, batch_size=4, n_max=3)\n", "val_loader = LinearSystemProblemLoader(seed=8, dataset_size=8, batch_size=4, n_max=3)" ], "id": "4116cb6fd07a79ef", "outputs": [], "execution_count": 39 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-07T21:26:25.272519279Z", "start_time": "2026-03-07T21:25:16.350903194Z" } }, "cell_type": "code", "source": [ "_ = trainer.train(\n", "    train_loader=train_loader,\n", "    val_loader=val_loader,\n", "    
eval_before_training=True,  # Compute the validation score once before the first epoch\n", "    n_epochs=3,\n", ")" ], "id": "db53c590d19776ed", "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Validation: 
100%|██████████| 2/2 [00:03<00:00, 1.83s/batch, score=2.1002e+00]\n", "Epoch 1/3: 100%|██████████| 16/16 [00:22<00:00, 1.42s/batch]\n", "Validation: 100%|██████████| 2/2 [00:01<00:00, 1.26batch/s, score=1.0779e+00]\n", "Epoch 2/3: 100%|██████████| 16/16 [00:19<00:00, 1.21s/batch]\n", "Validation: 100%|██████████| 2/2 [00:01<00:00, 1.11batch/s, score=9.9698e-01]\n", "Epoch 3/3: 100%|██████████| 16/16 [00:18<00:00, 1.13s/batch]\n", "Validation: 100%|██████████| 2/2 [00:01<00:00, 1.21batch/s, score=9.3471e-01]\n" ] } ], "execution_count": 40 } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 }