Source code for kdescent.descent

from functools import partial

import scipy.optimize
import tqdm.auto as tqdm
import numpy as np
import jax.random
import jax.numpy as jnp
import jaxopt
import optax

from . import keygen


def adam(lossfunc, guess, nsteps=100, param_bounds=None,
         learning_rate=0.01, randkey=1, const_randkey=False,
         **other_kwargs):
    """
    Perform Adam gradient descent, optionally within parameter bounds

    When bounds are supplied, each bounded parameter is mapped onto an
    unbounded variable, the descent is run on those unbounded variables,
    and the whole history is mapped back before being returned.

    Parameters
    ----------
    lossfunc : callable
        Function to be minimized via gradient descent. Must be compatible
        with jax.jit and jax.grad. Must have signature
        f(params, **other_kwargs)
    guess : array-like
        The starting parameters.
    nsteps : int, optional
        Number of gradient descent iterations to perform, by default 100
    param_bounds : Sequence, optional
        Lower and upper bounds of each parameter of "shape" (ndim, 2). Pass
        `None` as the bound for each unbounded parameter, by default None
    learning_rate : float, optional
        Initial Adam learning rate, by default 0.01
    randkey : int, optional
        Random seed or key, by default 1. If not None, lossfunc must accept
        the "randkey" keyword argument, e.g. `lossfunc(params, randkey=key)`
    const_randkey : bool, optional
        By default (False), randkey is regenerated at each gradient descent
        iteration. Remove this behavior by setting const_randkey=True
    **other_kwargs
        Extra keyword arguments forwarded unchanged to `adam_unbounded`.

    Returns
    -------
    params : jnp.array
        List of params throughout the entire gradient descent, starting
        with the initial guess, of shape (nsteps + 1, n_param)

    Raises
    ------
    ValueError
        If `param_bounds` is given and its length differs from `guess`.
    """
    if param_bounds is None:
        return adam_unbounded(
            lossfunc, guess, nsteps, learning_rate, randkey,
            const_randkey, **other_kwargs)

    # Explicit check rather than `assert`: asserts vanish under `python -O`,
    # after which zip() would silently drop the unmatched parameters.
    if len(guess) != len(param_bounds):
        raise ValueError(
            f"guess has {len(guess)} parameters but param_bounds has "
            f"{len(param_bounds)} entries")

    if hasattr(param_bounds, "tolist"):
        param_bounds = param_bounds.tolist()
    # The transforms take the bound as a static jit argument, which must be
    # hashable, hence tuples instead of lists.
    param_bounds = [b if b is None else tuple(b) for b in param_bounds]

    def ulossfunc(uparams, *args, **kwargs):
        # Evaluate the user's loss at the bounded image of the unbounded
        # variables the optimizer is actually stepping.
        params = apply_inverse_transforms(uparams, param_bounds)
        return lossfunc(params, *args, **kwargs)

    init_uparams = apply_transforms(guess, param_bounds)
    uparams = adam_unbounded(
        ulossfunc, init_uparams, nsteps, learning_rate, randkey,
        const_randkey, **other_kwargs)

    # The transforms iterate over parameters, so put the parameter axis
    # first, map back to bounded space, then restore (step, parameter) order.
    params = apply_inverse_transforms(uparams.T, param_bounds).T
    return params
def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
                   randkey=1, const_randkey=False, **other_kwargs):
    """
    Run Adam on an unconstrained problem and record every iterate

    Parameters
    ----------
    lossfunc : callable
        Function handed to the solver as the objective to minimize.
    guess : array-like
        The starting parameters; also the first entry of the history.
    nsteps : int, optional
        Number of solver updates to perform, by default 100
    learning_rate : float, optional
        Learning rate passed to `optax.adam`, by default 0.01
    randkey : int, optional
        Random seed or key, by default 1. If not None, a key derived from
        it is passed to the solver calls as the "randkey" keyword.
    const_randkey : bool, optional
        If True, the key derived before the first step is reused for every
        step; otherwise a new key is split off at each step.
    **other_kwargs
        Extra keyword arguments passed along with every solver call.

    Returns
    -------
    params : jnp.array
        The initial guess followed by the parameters after each update.
    """
    call_kwargs = dict(other_kwargs)

    # `rolling_key` stays None whenever the per-step key must not change:
    # either no key was requested at all, or it is to be held constant.
    rolling_key = None
    if randkey is not None:
        base_key = keygen.init_randkey(randkey)
        base_key, first_key = jax.random.split(base_key)
        call_kwargs["randkey"] = first_key
        if not const_randkey:
            rolling_key = base_key

    solver = jaxopt.OptaxSolver(
        opt=optax.adam(learning_rate), fun=lossfunc, maxiter=nsteps)
    opt_state = solver.init_state(guess, **call_kwargs)

    current = guess
    history = [current]
    for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
        if rolling_key is not None:
            rolling_key, step_key = jax.random.split(rolling_key)
            call_kwargs["randkey"] = step_key
        current, opt_state = solver.update(current, opt_state, **call_kwargs)
        history.append(current)

    return jnp.array(history)
def bfgs(lossfunc, guess, maxsteps=100, param_bounds=None, randkey=None):
    """
    Run L-BFGS-B to descend the gradient and optimize the model parameters,
    given an initial guess. Stochasticity must be held fixed via a random key

    Parameters
    ----------
    lossfunc : callable
        Function to be minimized via gradient descent. Must be compatible
        with jax.grad. Called as `lossfunc(params)`, or as
        `lossfunc(params, randkey=key)` when `randkey` is not None.
    guess : array-like
        The starting parameters.
    maxsteps : int, optional
        The maximum number of steps to take, by default 100.
    param_bounds : Sequence, optional
        Lower and upper bounds of each parameter of "shape" (ndim, 2). Pass
        `None` as the bound for each unbounded parameter, by default None.
        A `scipy.optimize.Bounds` instance is also accepted and is passed
        through unchanged.
    randkey : int | PRNG Key, optional
        Since BFGS requires a deterministic function, this key will be
        passed to `lossfunc` as the "randkey" kwarg as a constant at every
        iteration, by default None

    Returns
    -------
    OptimizeResult (contains the following attributes):
        message : str
            describes reason of termination
        success : boolean
            True if converged
        fun : float
            minimum loss found
        x : array
            parameters at minimum loss found
        jac : array
            gradient of loss at minimum loss found
        nfev : int
            number of function evaluations
        nit : int
            number of gradient descent iterations
    """
    kwargs = {}
    if randkey is not None:
        kwargs["randkey"] = keygen.init_randkey(randkey)

    # SciPy needs a (min, max) pair for every parameter, so expand the
    # documented per-parameter `None` shorthand into an explicit open pair.
    if (param_bounds is not None
            and not isinstance(param_bounds, scipy.optimize.Bounds)):
        param_bounds = [
            (None, None) if bound is None else tuple(bound)
            for bound in param_bounds
        ]

    # jac=True below tells SciPy this callable returns (loss, gradient).
    loss_and_grad_fn = jax.value_and_grad(
        lambda x: lossfunc(x, **kwargs))

    pbar = tqdm.trange(maxsteps, desc="BFGS Gradient Descent Progress")

    def callback(*_args, **_kwargs):
        pbar.update()

    try:
        results = scipy.optimize.minimize(
            loss_and_grad_fn, x0=guess, method="L-BFGS-B", jac=True,
            options=dict(maxiter=maxsteps), callback=callback,
            bounds=param_bounds)
    finally:
        # Release the progress bar even if the optimizer or loss raises.
        pbar.close()

    return results
def _finite_sides(lower, upper):
    """Tell, for each limit, whether it is present and finite."""
    has_lower = lower is not None and np.isfinite(lower)
    has_upper = upper is not None and np.isfinite(upper)
    return has_lower, has_upper


def apply_transforms(params, bounds):
    """Map each param to its unbounded counterpart, pairing it with the
    bound at the same position, and stack the results into one array."""
    unbounded = []
    for value, limits in zip(params, bounds):
        unbounded.append(transform(value, limits))
    return jnp.array(unbounded)


def apply_inverse_transforms(uparams, bounds):
    """Map each unbounded param back to its bounded counterpart, pairing it
    with the bound at the same position, and stack the results."""
    bounded = []
    for value, limits in zip(uparams, bounds):
        bounded.append(inverse_transform(value, limits))
    return jnp.array(bounded)


@partial(jax.jit, static_argnums=(1,))
def transform(param, bounds):
    """Transform param into unbound param

    `bounds` is either None (no constraint) or a (low, high) pair in which
    a side that is None or non-finite counts as open. It is a static jit
    argument, so it must be hashable.
    """
    if bounds is None:
        return param

    low, high = bounds
    has_low, has_high = _finite_sides(low, high)

    if has_low and has_high:
        # Two-sided: stretch the interval onto the real line with a tangent
        # centred on the midpoint.
        mid = (high + low) / 2.0
        scale = (high - low) / jnp.pi
        return scale * jnp.tan((param - mid) / scale)
    if has_low:
        # One-sided: d - 1/d, with d the distance above the lower limit.
        return param - low + 1.0 / (low - param)
    if has_high:
        # One-sided: 1/d - d, with d the distance below the upper limit.
        return param - high + 1.0 / (high - param)
    return param


@partial(jax.jit, static_argnums=(1,))
def inverse_transform(uparam, bounds):
    """Transform unbound param back into param

    Takes the same `bounds` forms as `transform` and undoes the
    corresponding mapping.
    """
    if bounds is None:
        return uparam

    low, high = bounds
    has_low, has_high = _finite_sides(low, high)

    if has_low and has_high:
        mid = (high + low) / 2.0
        scale = (high - low) / jnp.pi
        return mid + scale * jnp.arctan(uparam / scale)
    if has_low:
        # Positive root of d**2 - u*d - 1 = 0, shifted up by the lower limit.
        return 0.5 * (2.0 * low + uparam + jnp.sqrt(uparam**2 + 4))
    if has_high:
        # Same quadratic solved for the distance below the upper limit.
        return 0.5 * (2.0 * high + uparam - jnp.sqrt(uparam**2 + 4))
    return uparam