| """ | |
| A collection of helper functions for optimization with JAX. | |
| UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` is exists! | |
| """ | |
| import numpy as onp | |
| import scipy.optimize | |
| from jax import grad, jit | |
| from jax.tree_util import tree_flatten, tree_unflatten | |
| from jax.flatten_util import ravel_pytree | |
| from itertools import count | |
def minimize(fun, x0,
             method=None,
             args=(),
             bounds=None,
             constraints=(),
             tol=None,
             callback=None,
             options=None):
    """
    A simple wrapper for scipy.optimize.minimize using JAX.

    Args:
        fun: The objective function to be minimized, written in JAX code
            so that it is automatically differentiable. It is of type,
                ```fun: x, *args -> float```
            where `x` is a PyTree and `args` is a tuple of the fixed
            parameters needed to completely specify the function.

        x0: Initial guess represented as a JAX PyTree.

        args: tuple, optional. Extra arguments passed to the objective
            function and its derivative. Must consist of valid JAX types;
            e.g. the leaves of the PyTree must be floats.

        _The remainder of the keyword arguments are inherited from
        `scipy.optimize.minimize`, and their descriptions are copied here for
        convenience._

        method : str or callable, optional
            Type of solver. Should be one of
                - 'Nelder-Mead' :ref:`(see here) <optimize.minimize-neldermead>`
                - 'Powell'      :ref:`(see here) <optimize.minimize-powell>`
                - 'CG'          :ref:`(see here) <optimize.minimize-cg>`
                - 'BFGS'        :ref:`(see here) <optimize.minimize-bfgs>`
                - 'Newton-CG'   :ref:`(see here) <optimize.minimize-newtoncg>`
                - 'L-BFGS-B'    :ref:`(see here) <optimize.minimize-lbfgsb>`
                - 'TNC'         :ref:`(see here) <optimize.minimize-tnc>`
                - 'COBYLA'      :ref:`(see here) <optimize.minimize-cobyla>`
                - 'SLSQP'       :ref:`(see here) <optimize.minimize-slsqp>`
                - 'trust-constr':ref:`(see here) <optimize.minimize-trustconstr>`
                - 'dogleg'      :ref:`(see here) <optimize.minimize-dogleg>`
                - 'trust-ncg'   :ref:`(see here) <optimize.minimize-trustncg>`
                - 'trust-exact' :ref:`(see here) <optimize.minimize-trustexact>`
                - 'trust-krylov' :ref:`(see here) <optimize.minimize-trustkrylov>`
                - custom - a callable object (added in scipy version 0.14.0),
                  see below for description.
            If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
            depending on whether the problem has constraints or bounds.

        bounds : sequence or `Bounds`, optional
            Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
            trust-constr methods. There are two ways to specify the bounds:
                1. Instance of `Bounds` class.
                2. Sequence of ``(min, max)`` pairs for each element in `x`.
                   None is used to specify no bound.
            Note that in order to use `bounds` you will need to manually
            flatten them in the same order as your inputs `x0`.

        constraints : {Constraint, dict} or List of {Constraint, dict}, optional
            Constraints definition (only for COBYLA, SLSQP and trust-constr).
            Constraints for 'trust-constr' are defined as a single object or a
            list of objects specifying constraints to the optimization problem.
            Available constraints are:
                - `LinearConstraint`
                - `NonlinearConstraint`
            Constraints for COBYLA, SLSQP are defined as a list of
            dictionaries. Each dictionary has fields:
                type : str
                    Constraint type: 'eq' for equality, 'ineq' for inequality.
                fun : callable
                    The function defining the constraint.
                jac : callable, optional
                    The Jacobian of `fun` (only for SLSQP).
                args : sequence, optional
                    Extra arguments to be passed to the function and Jacobian.
            Equality constraint means that the constraint function result is
            to be zero, whereas inequality means that it is to be non-negative.
            Note that COBYLA only supports inequality constraints.
            Note that in order to use `constraints` you will need to manually
            flatten them in the same order as your inputs `x0`.

        tol : float, optional
            Tolerance for termination. For detailed control, use
            solver-specific options.

        options : dict, optional
            A dictionary of solver options. All methods accept the following
            generic options:
                maxiter : int
                    Maximum number of iterations to perform. Depending on the
                    method, each iteration may use several function evaluations.
                disp : bool
                    Set to True to print convergence messages.
            For method-specific options, see :func:`show_options()`.

        callback : callable, optional
            Called after each iteration. For 'trust-constr' it is a callable
            with the signature:
                ``callback(xk, OptimizeResult state) -> bool``
            where ``xk`` is the current parameter vector, represented as a
            PyTree, and ``state`` is an `OptimizeResult` object, with the same
            fields as the ones from the return. If callback returns True, the
            algorithm execution is terminated.
            For all the other methods, the signature is:
                ```callback(xk)```
            where `xk` is the current parameter vector, represented as a PyTree.

    Returns:
        res: The optimization result represented as an ``OptimizeResult``
            object. Important attributes are:
                ``x``: the solution, represented as a JAX PyTree
                ``success``: a Boolean flag indicating if the optimizer
                    exited successfully
                ``message``: describes the cause of the termination.
            See `scipy.optimize.OptimizeResult` for a description of other
            attributes.
    """
    # Use ravel_pytree to convert the params x0 from a PyTree to a flat array,
    # and get back a function that undoes the flattening.
    x0_flat, unravel = ravel_pytree(x0)

    # Wrap the objective function to consume flat "original" numpy (onp)
    # arrays and produce scalar outputs.
    def fun_wrapper(x_flat, *args):
        x = unravel(x_flat)
        return float(fun(x, *args))

    # Wrap the gradient in a similar manner.
    jac = jit(grad(fun))
    def jac_wrapper(x_flat, *args):
        x = unravel(x_flat)
        g_flat, _ = ravel_pytree(jac(x, *args))
        return onp.array(g_flat)

    # Wrap the callback to consume a PyTree.
    def callback_wrapper(x_flat, *args):
        if callback is not None:
            x = unravel(x_flat)
            return callback(x, *args)

    # Minimize with scipy.
    results = scipy.optimize.minimize(fun_wrapper,
                                      x0_flat,
                                      args=args,
                                      method=method,
                                      jac=jac_wrapper,
                                      callback=callback_wrapper,
                                      bounds=bounds,
                                      constraints=constraints,
                                      tol=tol,
                                      options=options)

    # Pack the output back into a PyTree.
    results["x"] = unravel(results["x"])
    return results
```
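For reference, here is a minimal usage sketch; the toy objective, parameter names, and bounds below are made up for illustration:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical toy objective over a PyTree (here, a dict) of parameters.
def loss(params):
    return jnp.sum(params["w"] ** 2) + (params["b"] - 1.0) ** 2

params0 = {"b": jnp.array(0.0), "w": jnp.ones(3)}
results = minimize(loss, params0, method="BFGS")
print(results["x"])  # a PyTree with the same structure as params0

# To use `bounds`, flatten them in the same order that ravel_pytree
# flattens params0 (for dicts, keys are traversed in sorted order).
flat0, _ = ravel_pytree(params0)
bounds = [(-10.0, 10.0)] * flat0.size
results = minimize(loss, params0, method="L-BFGS-B", bounds=bounds)
```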
It would still let you use `bounds` and other options which aren't implemented in JAX yet.
See https://jaxopt.github.io/ for a new library that might be useful.
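For instance, jaxopt ships a scipy wrapper of its own. A minimal sketch, assuming the `jaxopt.ScipyMinimize` API at the time of writing (the toy objective is made up):

```python
import jax.numpy as jnp
from jaxopt import ScipyMinimize

# Hypothetical toy objective; the parameters are the first argument.
def loss(w):
    return jnp.sum((w - 2.0) ** 2)

solver = ScipyMinimize(fun=loss, method="BFGS")
params, state = solver.run(jnp.zeros(3))
```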
Hi Slinderman, thanks for the wrapper. Is there a way to make the code vmappable? Currently I can use vmap on `jax.scipy.optimize.minimize`, but the downside is that it only supports the BFGS algorithm. Also, the scipy minimize wrapper in jaxopt is not vmappable. When I run the code below, I get a JAX tracer conversion error.
```python
def do_minimize(p, x, y, z, lb, ub, smf):
    return minimize(cost_fun, p, args=(x, y, z, lb, ub, smf),
                    method='TNC', tol=1e-12, options={'maxiter': 20000})

sol = jax.vmap(do_minimize)(par_log, F, Y, sigma_Y, lb_mat, ub_mat, smoothing_mat)
```
```
# Traceback (condensed): vmap traces do_minimize, so the wrapper's call to
# scipy.optimize.minimize receives a JAX tracer instead of a concrete array
# and raises jax.errors.TracerArrayConversionError.
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
I just realized that using `list(map(func, *args))` instead of vmap works well.
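A minimal self-contained sketch of that map-based batching; the quadratic objective and batch of targets here are made up for illustration:

```python
import jax.numpy as jnp

# Solve one small problem per row of `targets`. The scipy solver cannot be
# traced by jax.vmap, so we loop in plain Python instead.
def do_minimize(target):
    return minimize(lambda p: jnp.sum((p - target) ** 2),
                    jnp.zeros(3), method='TNC')

targets = jnp.arange(6.0).reshape(2, 3)
solutions = list(map(do_minimize, targets))
```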
Hi, I wanted to check if there's a similar wrapper for linprog? Thank you!
Would there be a benefit to also JIT-ing the objective `fun`? Context: methods like Nelder-Mead do not use the `jac`. (See the sketch below.)
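Plausibly yes for gradient-free methods, since they evaluate the objective many times and never touch the jitted `jac`. A sketch of the tweak inside the wrapper (not part of the original gist), mirroring the existing `jac = jit(grad(fun))` line:

```python
# Inside `minimize`, compile the objective once so that scipy's repeated
# function evaluations reuse the compiled version.
fun_jit = jit(fun)

def fun_wrapper(x_flat, *args):
    x = unravel(x_flat)
    # float() pulls the result back to the host as a Python scalar.
    return float(fun_jit(x, *args))
```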
I think this is still useful when using any method other than `BFGS`, since that is the only one that `jax.scipy.optimize.minimize` currently supports (at the time of writing this).
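For comparison, a minimal sketch of the built-in version, which works on flat arrays rather than PyTrees (the toy objective is made up):

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize as jax_minimize

def loss_flat(x):
    return jnp.sum((x - 1.0) ** 2)

res = jax_minimize(loss_flat, jnp.zeros(3), method="BFGS")
print(res.x)  # the solution as a flat jnp array
```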