Source code for energnn.problem.dataset

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

import json
import os
import pickle
from datetime import datetime

from .metadata import ProblemMetadata


class ProblemDataset(dict):
    """
    Dictionary-like container for datasets of problem instances.

    Stores dataset-level metadata and a list of ProblemMetadata instances.
    Every field is kept as a plain dictionary entry (so the object can be
    dumped with ``json`` / ``pickle`` like a regular ``dict``) and is also
    exposed through a read-only property of the same name.

    :param name: Identifier for the dataset.
    :param split: Dataset split name (e.g., "train", "val", "test").
    :param version: Version number of the dataset.
    :param instances: List of ProblemMetadata objects describing each instance.
    :param size: Total number of instances in the dataset.
    :param context_max_shape: Maximum dimensions of context graphs across instances.
    :param decision_max_shape: Maximum dimensions of decision graphs across instances.
    :param generation_date: Timestamp when the dataset was generated.
    :param selection_criteria: A dictionary that contains some criteria.
    :param tags: Key-value tags associated with the dataset for grouping or
        filtering. ``None`` (the default) is replaced by a new empty dict.
    """

    def __init__(
        self,
        name: str,
        split: str,
        version: int,
        instances: "list[ProblemMetadata]",
        size: int,
        context_max_shape: dict,
        decision_max_shape: dict,
        generation_date: datetime,
        selection_criteria: dict,
        tags=None,
    ) -> None:
        super().__init__()
        # A fresh dict per instance: a mutable default argument would be
        # shared between every dataset created without explicit tags.
        if tags is None:
            tags = {}
        # Insertion order is the key order of the JSON output of `to_json`.
        self["name"] = name
        self["split"] = split
        self["version"] = version
        self["size"] = size
        self["context_max_shape"] = context_max_shape
        self["decision_max_shape"] = decision_max_shape
        self["generation_date"] = generation_date
        self["selection_criteria"] = selection_criteria
        self["instances"] = instances
        self["tags"] = tags

    def get_infos_for_feature_store(self) -> dict:
        """
        Retrieve the dataset's information to send to the feature store.

        Excludes the full list of instances. The result is a shallow copy:
        nested values (shapes, criteria, tags) are the same objects as the
        ones held by the dataset.

        :returns: A plain dict containing all dataset fields except
            "instances", with "generation_date" converted with ``str()``
            (for a ``datetime`` this is ``YYYY-MM-DD HH:MM:SS[.ffffff]``,
            i.e. a space rather than ``T`` between date and time).
        :raises KeyError: If the "instances" entry has been removed.
        """
        res = self.copy()  # dict.copy() returns a plain dict, not a ProblemDataset
        res.pop("instances")
        res["generation_date"] = str(self.generation_date)
        return res

    def get_locally_missing_instances(self, path: str) -> "list[ProblemMetadata]":
        """
        Identify instances whose files are missing in a local directory.

        :param path: Base directory where instance files should be stored.
        :returns: List of metadata of instances for which
            ``os.path.join(path, instance.storage_path)`` does not exist.
        """
        return [
            instance
            for instance in self.instances
            if not os.path.exists(os.path.join(path, instance.storage_path))
        ]

    def get_instance_paths(self) -> list[str]:
        """
        List the storage paths for all instances in the dataset.

        :returns: List of instance file paths as stored in metadata.
        """
        return [instance.storage_path for instance in self.instances]

    def to_json(self, file_path: str):
        """
        Serialize the dataset to a JSON file for human-readable archives.

        Note: JSON output will not preserve Python types on loading. Any value
        the JSON encoder cannot handle natively (e.g. the ``datetime`` stored
        under "generation_date") is written as its ``str()`` representation.

        :param file_path: Target JSON file path.
        :raises IOError: If writing to the file system fails.
        """
        with open(file_path, "w", encoding="utf-8") as handle:
            json.dump(self, handle, indent=4, ensure_ascii=False, default=str)

    def to_pickle(self, file_path: str):
        """
        Serialize the dataset to a pickle file for efficient reload.

        :param file_path: Target pickle file path.
        :raises IOError: If writing to the file system fails.
        """
        with open(file_path, "wb") as handle:
            pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def from_pickle(cls, file_path: str) -> "ProblemDataset":
        """
        Load a dataset from a pickle file produced by `to_pickle`.

        The unpickled object is returned as-is; its type is not checked.

        :param file_path: Source pickle file path.
        :returns: Restored `ProblemDataset` instance.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(file_path, "rb") as handle:
            dataset = pickle.load(handle)
        return dataset

    @property
    def name(self) -> str:
        """Identifier for the dataset."""
        return self.get("name")

    @property
    def split(self) -> str:
        """Dataset split name."""
        return self.get("split")

    @property
    def version(self) -> int:
        """Version number of the dataset."""
        return self.get("version")

    @property
    def size(self) -> int:
        """Total number of instances, as declared at construction."""
        return self.get("size")

    @property
    def context_max_shape(self) -> dict:
        """Maximum dimensions of context graphs across instances."""
        return self.get("context_max_shape")

    @property
    def decision_max_shape(self) -> dict:
        """Maximum dimensions of decision graphs across instances."""
        return self.get("decision_max_shape")

    @property
    def generation_date(self) -> datetime:
        """Timestamp when the dataset was generated, as passed to the constructor."""
        return self.get("generation_date")

    @property
    def selection_criteria(self) -> dict:
        """Selection criteria dictionary."""
        return self.get("selection_criteria")

    @property
    def instances(self) -> "list[ProblemMetadata]":
        """Metadata objects describing each instance."""
        return self.get("instances")

    @property
    def tags(self) -> dict:
        """Key-value tags associated with the dataset."""
        return self.get("tags")