Problem.get_gradient
- abstract Problem.get_gradient(*, decision, get_info=False, step=None)
Compute the gradient graph \(\nabla_y f\) for a given decision \(y\).
The gradient guides optimization algorithms such as gradient descent.
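For reference, one plain gradient-descent iteration would use the returned gradient graph as below. The iteration index \(t\) and step size \(\eta\) are symbols introduced here for illustration; they are not parameters of this method. Because the gradient graph has the same structure as the decision, the update can be applied feature by feature:

\[ y_{t+1} = y_t - \eta \, \nabla_y f(y_t) \]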
- Parameters:
decision (JaxGraph) – A decision graph at which to evaluate the gradient.
get_info (bool) – Flag indicating whether additional information should be returned for tracking purposes.
step (int | None) – Training step number passed by the trainer. Useful for scheduling.
- Returns:
A tuple containing:
  - JaxGraph: The gradient graph, with the same structure as decision.
  - dict: A dictionary of additional information (empty if get_info=False).
- Raises:
NotImplementedError – If the subclass does not override this method.
- Return type:
tuple[JaxGraph, dict]
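The sketch below illustrates the contract described above: an abstract base that a subclass overrides, a gradient with the same structure as the decision, an info dictionary that stays empty unless get_info=True, and step passed through for scheduling. It does not use the real library. The plain dict[str, list[float]] standing in for JaxGraph, and the QuadraticProblem and gradient_descent names, are hypothetical and chosen only for illustration. The toy objective is f(y) = 0.5 * sum_k (y_k - target_k)^2, whose gradient is y - target.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for JaxGraph: feature name -> list of values.
Graph = Dict[str, List[float]]


class Problem(ABC):
    """Minimal stand-in mirroring the documented signature (not the library class)."""

    @abstractmethod
    def get_gradient(self, *, decision: Graph, get_info: bool = False,
                     step: Optional[int] = None) -> Tuple[Graph, dict]:
        # ABC prevents instantiating subclasses that do not override this method;
        # the body raises if it is ever reached directly (e.g. via super()).
        raise NotImplementedError


class QuadraticProblem(Problem):
    """Hypothetical problem f(y) = 0.5 * sum_k (y_k - target_k)^2."""

    def __init__(self, target: Graph):
        self.target = target

    def get_gradient(self, *, decision: Graph, get_info: bool = False,
                     step: Optional[int] = None) -> Tuple[Graph, dict]:
        # grad_y f = y - target, built with the same keys and lengths as decision.
        grad = {
            key: [y - t for y, t in zip(values, self.target[key])]
            for key, values in decision.items()
        }
        info: dict = {}
        if get_info:
            # Example tracking quantities: objective value and the trainer's step.
            info["objective"] = 0.5 * sum(g * g for vals in grad.values() for g in vals)
            info["step"] = step
        return grad, info


def gradient_descent(problem: Problem, decision: Graph,
                     lr: float = 0.1, n_steps: int = 100) -> Graph:
    """Plain gradient descent: y <- y - lr * grad, applied feature by feature."""
    for step in range(n_steps):
        grad, _ = problem.get_gradient(decision=decision, step=step)
        decision = {
            key: [y - lr * g for y, g in zip(values, grad[key])]
            for key, values in decision.items()
        }
    return decision


problem = QuadraticProblem(target={"node": [1.0, -2.0], "edge": [0.5]})
start = {"node": [0.0, 0.0], "edge": [0.0]}
grad, info = problem.get_gradient(decision=start, get_info=True, step=0)
# grad == {"node": [-1.0, 2.0], "edge": [-0.5]}, info["objective"] == 2.625
```

Always returning a (gradient, info) pair, even when info is empty, lets a caller unpack the result the same way regardless of get_info.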