make_tf_fun

safe_learning.utilities.make_tf_fun(return_type, gradient=None, stateful=True)

Convert a Python function to a TensorFlow function.

Parameters:
return_type : list

A list of TensorFlow return types, one per output of the function. Its length must match the number of upstream gradients that gradient accepts.

gradient : callable, optional

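A function that provides the gradient. It takes the op and one upstream gradient per output of the function as inputs, and returns one gradient for each input of the function. If stateful is False, TensorFlow does not appear to compute gradients at all.

stateful : bool, optional

Whether the wrapped function is treated as stateful by TensorFlow. Defaults to True; see the note on gradient above.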
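To make the gradient contract concrete, here is a minimal pure-Python sketch, assuming a function with two inputs and two outputs, f(x, y) = (x + y, x * y). Because there are two outputs, the matching return_type would need two entries (for example [tf.float64, tf.float64]), and the gradient callable receives two upstream gradients after op. The names sum_and_product and sum_and_product_grad, and the SimpleNamespace stand-in for op (which only exposes op.inputs), are hypothetical; in TensorFlow, op.inputs holds tensors rather than floats.

```python
from types import SimpleNamespace


def sum_and_product(x, y):
    """Forward function with two outputs: (x + y, x * y)."""
    return x + y, x * y


def sum_and_product_grad(op, grad_sum, grad_prod):
    """Gradient callable: takes op plus one upstream gradient per output
    and returns one gradient per input, combined via the chain rule."""
    x, y = op.inputs
    # d(x + y)/dx = 1 and d(x * y)/dx = y
    grad_x = grad_sum * 1.0 + grad_prod * y
    # d(x + y)/dy = 1 and d(x * y)/dy = x
    grad_y = grad_sum * 1.0 + grad_prod * x
    return grad_x, grad_y


# Hypothetical stand-in for a TensorFlow op: only .inputs is used here.
op = SimpleNamespace(inputs=(3.0, 5.0))
print(sum_and_product_grad(op, 1.0, 1.0))  # (6.0, 4.0)
```

Two upstream gradients go in and two input gradients come out, which is why the number of return types and the gradient's arity have to agree.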

Returns:
A TensorFlow function. If gradient is given, it is registered as the gradient of that function.
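Since the registered gradient is written by hand, it can be worth checking it numerically before wrapping the function. The sketch below compares a gradient callable that follows the (op, one upstream gradient per output) convention described for gradient against central finite differences. The helper name check_gradient, the SimpleNamespace stand-in for op, the step size and the tolerance are all assumptions for illustration and are not part of safe_learning; the check only handles single-output functions.

```python
import math
from types import SimpleNamespace


def sin_fun(x):
    """Single-output forward function."""
    return math.sin(x)


def sin_grad(op, grad):
    """One upstream gradient in, one gradient per input out."""
    (x,) = op.inputs
    return (grad * math.cos(x),)


def check_gradient(fun, grad_fun, inputs, eps=1e-6, tol=1e-6):
    """Compare grad_fun with central finite differences of a
    single-output fun, using an upstream gradient of 1.0."""
    op = SimpleNamespace(inputs=tuple(inputs))  # hypothetical op stand-in
    analytic = grad_fun(op, 1.0)
    for i in range(len(inputs)):
        plus = list(inputs)
        minus = list(inputs)
        plus[i] += eps
        minus[i] -= eps
        # Central difference approximation of d fun / d inputs[i]
        numeric = (fun(*plus) - fun(*minus)) / (2 * eps)
        if abs(numeric - analytic[i]) > tol:
            return False
    return True


print(check_gradient(sin_fun, sin_grad, [0.3]))  # True
```

A gradient that returns the wrong derivative (for example sin instead of cos) makes check_gradient return False.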