Source code for energnn.problem.problem
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
from energnn.graph import GraphStructure, JaxGraph
from energnn.problem.metadata import ProblemMetadata
[docs]
class Problem(ABC):
"""
Base abstract class for graph-based optimization or learning problems.
Subclasses must implement methods to retrieve the problem context graph,
an initial zero decision graph, compute gradients, evaluate score,
and provide problem metadata.
Notes:
- All returned Graph objects must adhere to the energnn.graph.Graph API.
- Methods returning tuples will return additional information in the dict when
`get_info=True` for tracking purpose.
"""
[docs]
@abstractmethod
def __init__(self):
"""
Initialize the problem instance.
This constructor may accept parameters specific to the problem definition,
such as hyperparameters, or graph dimensions.
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
[docs]
@abstractmethod
def get_context(self, get_info: bool = False, step: int | None = None) -> tuple[JaxGraph, dict]:
"""
Retrieve the context graph math:`x` of the problem instance.
The context graph encompasses all fixed inputs required to define
the instance, such as node features, hyper-edge indices, and any static attributes.
:param get_info: Flag indicating if additional information should be returned for tracking purpose.
:param step: Training step number passed by the trainer. Useful for scheduling.
:return: A tuple containing:
- **Graph**: The context graph object.
- **dict**: A dictionary of additional information (empty if `get_info=False`).
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
[docs]
@abstractmethod
def get_gradient(self, *, decision: JaxGraph, get_info: bool = False, step: int | None = None) -> tuple[JaxGraph, dict]:
r"""
Compute the gradient graph :math:`\nabla_y f` for a given decision :math:`y`.
The gradient guides optimization algorithms such as gradient descent.
:param decision: A decision graph at which to evaluate the gradient.
:param get_info: Flag indicating if additional information should be returned for tracking purpose.
:param step: Training step number passed by the trainer. Useful for scheduling.
:return: A tuple containing:
- **Graph**: The gradient graph with the same structure as decision.
- **dict**: A dictionary of additional information (empty if `get_info=False`).
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
[docs]
@abstractmethod
def get_score(self, *, decision: JaxGraph, get_info: bool = False, step: int | None = None) -> tuple[float, dict]:
"""Should return a scalar `score` that evaluates the decision graph :math:`y`.
:param decision: The decision graph to evaluate.
:param get_info: Flag indicating if additional information should be returned for tracking purpose.
:param step: Training step number passed by the trainer. Useful for scheduling.
:return: A tuple containing:
- **float**: A float as score value.
- **dict**: A dictionary of additional information (empty if `get_info=False`).
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
@abstractmethod
def get_metadata(self) -> ProblemMetadata:
"""
Retrieve metadata describing problem characteristics.
Metadata include problem name, configuration ID, version, context shape, decision shape.
:return: A ProblemMetadata instance encapsulating metadata.
:raises NotImplementedError: if subclass does not override this constructor.
"""
raise NotImplementedError
@abstractmethod
def save(self, *, path: str) -> None:
"""
Serialize the problem instance to disk.
This method should make all necessary states persist to reconstruct
the problem later.
:param path: Filesystem path or directory to save problem data.
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
@property
@abstractmethod
def context_structure(self) -> GraphStructure:
"""Should define the structure of all context graphs."""
raise NotImplementedError
@property
@abstractmethod
def decision_structure(self) -> GraphStructure:
"""Should define the structure of all decision graphs."""
raise NotImplementedError