Source code for energnn.problem.loader

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from typing import Iterator, Sized

from energnn.graph import GraphStructure
from .batch import ProblemBatch


class ProblemLoader(ABC, Sized, Iterator[ProblemBatch]):
    """
    Abstract base class for problem loaders that yield batches of problem instances.

    Iterates over problem instances in batches, optionally shuffling the dataset.
    Concrete subclasses must implement the constructor, the iterator protocol
    (``__iter__`` / ``__next__``), ``__len__``, and the ``context_structure`` /
    ``decision_structure`` properties.

    :param batch_size: Number of instances per batch returned by the iterator.
    :param shuffle: If true, randomly shuffle the dataset.
    """

    @abstractmethod
    def __init__(self, batch_size: int, shuffle: bool = False):
        """
        Initialize the problem loader.

        The base implementation only raises, so subclasses must override it and
        should not delegate to it through ``super().__init__``.

        :param batch_size: Number of instances per batch returned by the iterator.
        :param shuffle: If true, randomly shuffle the dataset.
        :raises NotImplementedError: If the subclass does not override this constructor.
        """
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[ProblemBatch]:
        """
        Return the loader iterator.

        Should optionally reshuffle data if `shuffle=True`.

        :returns: Iterator over batches.
        :raises NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> ProblemBatch:
        """
        Retrieve the next batch of problems.

        :returns: A `ProblemBatch` containing up to `batch_size` problem instances.
        :raises StopIteration: If there are no further items.
        :raises NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """
        Number of batches per epoch.

        :returns: The number of batches the iterator yields in one pass.
        :raises NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def context_structure(self) -> GraphStructure:
        """
        Should define the structure of all context graphs.

        :returns: The `GraphStructure` shared by every context graph.
        :raises NotImplementedError: If the subclass does not override this property.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def decision_structure(self) -> GraphStructure:
        """
        Should define the structure of all decision graphs.

        :returns: The `GraphStructure` shared by every decision graph.
        :raises NotImplementedError: If the subclass does not override this property.
        """
        raise NotImplementedError