Source code for energnn.problem.loader
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
from typing import Iterator, Sized
from energnn.graph import GraphStructure
from .batch import ProblemBatch
[docs]
class ProblemLoader(ABC, Sized, Iterator[ProblemBatch]):
"""
Abstract base class for problem loaders that yield batches of problem instances.
Iterates over problem instances in batches, optionally shuffling the dataset.
:param batch_size: Number of instances per batch returned by the iterator.
:param shuffle: If true, randomly shuffle the dataset.
"""
[docs]
@abstractmethod
def __init__(self, batch_size: int, shuffle: bool = False):
"""
Initialize the problem loader.
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
[docs]
@abstractmethod
def __iter__(self) -> Iterator[ProblemBatch]:
"""
Return the loader iterator.
Should optionally reshuffle data if `shuffle=True`.
:returns: Iterator over batches.
"""
raise NotImplementedError
[docs]
@abstractmethod
def __next__(self) -> ProblemBatch:
"""
Retrieve the next batch of problems.
:returns: A `ProblemBatch` containing up to `batch_size` problem instances.
:raises StopIteration: If there are no further items.
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
[docs]
@abstractmethod
def __len__(self) -> int:
"""
Number of batches per epoch.
:raises NotImplementedError: If the subclass does not override this constructor.
"""
raise NotImplementedError
@property
@abstractmethod
def context_structure(self) -> GraphStructure:
"""Should define the structure of all context graphs."""
raise NotImplementedError
@property
@abstractmethod
def decision_structure(self) -> GraphStructure:
"""Should define the structure of all decision graphs."""
raise NotImplementedError