ProblemBatch.get_gradient

abstract ProblemBatch.get_gradient(*, decision, get_info=False, step=None)

Compute gradients \(\nabla_y f\) for a batch of decision graphs \(y\).

Parameters:
  • decision (JaxGraph) – Batched decision graph at which to evaluate the gradient.

  • get_info (bool) – Flag indicating whether additional information should be returned for tracking purposes.

  • step (int | None) – Training step number passed by the trainer. Useful for scheduling.

Returns:

A tuple of:

  • JaxGraph – The batched gradient \(\nabla_y f\), one gradient per decision graph in the batch.

  • dict – A dictionary of additional information (empty if get_info=False).

Raises:

NotImplementedError – If the subclass does not override this method.

Return type:

tuple[JaxGraph, dict]
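The contract above (keyword-only arguments, a gradient graph plus an info dict that stays empty unless get_info=True) can be illustrated with a minimal sketch. It assumes JAX is installed and uses hypothetical stand-ins: ToyGraph plays the role of JaxGraph, ToyProblemBatch the role of ProblemBatch, and QuadraticProblemBatch is an invented subclass. None of these names belong to the library, and the info keys "objective" and "step" are illustrative choices. The toy objective f(y) = 0.5 * ||y - target||^2 is used only because its gradient, y - target, is easy to check by hand.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class ToyGraph:
    """Hypothetical stand-in for JaxGraph: one feature row per batch item."""
    features: jnp.ndarray  # shape (batch, n_features)


class ToyProblemBatch(ABC):
    """Hypothetical stand-in for ProblemBatch, reproducing only get_gradient."""

    @abstractmethod
    def get_gradient(self, *, decision: ToyGraph, get_info: bool = False,
                     step: Optional[int] = None) -> tuple[ToyGraph, dict]:
        # Subclasses must override this method.
        raise NotImplementedError


class QuadraticProblemBatch(ToyProblemBatch):
    """Hypothetical problem: f(y) = 0.5 * ||y - target||^2 for each batch item."""

    def __init__(self, target: jnp.ndarray):
        self.target = target  # shape (batch, n_features)

    def _objective(self, y: jnp.ndarray) -> jnp.ndarray:
        # Summing over the batch is safe here: row i of the gradient
        # depends only on row i of y, so each item gets its own gradient.
        return 0.5 * jnp.sum((y - self.target) ** 2)

    def get_gradient(self, *, decision: ToyGraph, get_info: bool = False,
                     step: Optional[int] = None) -> tuple[ToyGraph, dict]:
        # jax.grad differentiates the scalar objective w.r.t. the decision features.
        grad = jax.grad(self._objective)(decision.features)
        info = {}
        if get_info:
            # Illustrative tracking entries (hypothetical keys).
            info["objective"] = float(self._objective(decision.features))
            info["step"] = step
        return ToyGraph(features=grad), info
```

Because the arguments are keyword-only, a call such as problem.get_gradient(decision=y, get_info=True, step=3) works, while passing the decision graph positionally raises TypeError. Keeping the info dict empty by default means callers that ignore tracking pay no extra cost.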