Merging with private repo, mostly fixed

This commit is contained in:
Alan Saul 2015-03-27 14:17:03 +00:00
parent 6a1de2bfc2
commit 0ea3d33695
8 changed files with 768 additions and 318 deletions

View file

@ -4,6 +4,16 @@
import numpy as np
from config import *
_lim_val = np.finfo(np.float64).max
_lim_val_exp = np.log(_lim_val)
_lim_val_square = np.sqrt(_lim_val)
_lim_val_cube = np.power(_lim_val, -3)
def safe_exp(f):
clip_f = np.clip(f, -np.inf, _lim_val_exp)
return np.exp(clip_f)
def chain_1(df_dg, dg_dx):
"""
Generic chaining function for first derivative
@ -11,6 +21,11 @@ def chain_1(df_dg, dg_dx):
.. math::
\\frac{d(f . g)}{dx} = \\frac{df}{dg} \\frac{dg}{dx}
"""
if np.all(dg_dx==1.):
return df_dg
if len(df_dg) > 1 and df_dg.shape[-1] > 1:
import ipdb; ipdb.set_trace() # XXX BREAKPOINT
raise NotImplementedError('Not implemented for matricies yet')
return df_dg * dg_dx
def chain_2(d2f_dg2, dg_dx, df_dg, d2g_dx2):
@ -20,7 +35,13 @@ def chain_2(d2f_dg2, dg_dx, df_dg, d2g_dx2):
.. math::
\\frac{d^{2}(f . g)}{dx^{2}} = \\frac{d^{2}f}{dg^{2}}(\\frac{dg}{dx})^{2} + \\frac{df}{dg}\\frac{d^{2}g}{dx^{2}}
"""
return d2f_dg2*(dg_dx**2) + df_dg*d2g_dx2
if np.all(dg_dx==1.) and np.all(d2g_dx2 == 0):
return d2f_dg2
if len(d2f_dg2) > 1 and d2f_dg2.shape[-1] > 1:
raise NotImplementedError('Not implemented for matricies yet')
#dg_dx_2 = np.clip(dg_dx, 1e-12, _lim_val_square)**2
dg_dx_2 = dg_dx**2
return d2f_dg2*(dg_dx_2) + df_dg*d2g_dx2
def chain_3(d3f_dg3, dg_dx, d2f_dg2, d2g_dx2, df_dg, d3g_dx3):
"""
@ -29,11 +50,18 @@ def chain_3(d3f_dg3, dg_dx, d2f_dg2, d2g_dx2, df_dg, d3g_dx3):
.. math::
\\frac{d^{3}(f . g)}{dx^{3}} = \\frac{d^{3}f}{dg^{3}}(\\frac{dg}{dx})^{3} + 3\\frac{d^{2}f}{dg^{2}}\\frac{dg}{dx}\\frac{d^{2}g}{dx^{2}} + \\frac{df}{dg}\\frac{d^{3}g}{dx^{3}}
"""
return d3f_dg3*(dg_dx**3) + 3*d2f_dg2*dg_dx*d2g_dx2 + df_dg*d3g_dx3
if np.all(dg_dx==1.) and np.all(d2g_dx2==0) and np.all(d3g_dx3==0):
return d3f_dg3
if ( (len(d2f_dg2) > 1 and d2f_dg2.shape[-1] > 1)
or (len(d3f_dg3) > 1 and d3f_dg3.shape[-1] > 1)):
raise NotImplementedError('Not implemented for matricies yet')
#dg_dx_3 = np.clip(dg_dx, 1e-12, _lim_val_cube)**3
dg_dx_3 = dg_dx**3
return d3f_dg3*(dg_dx_3) + 3*d2f_dg2*dg_dx*d2g_dx2 + df_dg*d3g_dx3
def opt_wrapper(m, **kwargs):
"""
This function just wraps the optimization procedure of a GPy
Thit function just wraps the optimization procedure of a GPy
object so that optimize() pickleable (necessary for multiprocessing).
"""
m.optimize(**kwargs)
@ -96,3 +124,47 @@ from :class:ndarray)"""
if len(param) == 1:
return param[0].view(np.ndarray)
return [x.view(np.ndarray) for x in param]
def blockify_hessian(func):
def wrapper_func(self, *args, **kwargs):
# Invoke the wrapped function first
retval = func(self, *args, **kwargs)
# Now do something here with retval and/or action
if self.not_block_really and (retval.shape[0] != retval.shape[1]):
return np.diagflat(retval)
else:
return retval
return wrapper_func
def blockify_third(func):
def wrapper_func(self, *args, **kwargs):
# Invoke the wrapped function first
retval = func(self, *args, **kwargs)
# Now do something here with retval and/or action
if self.not_block_really and (len(retval.shape) < 3):
num_data = retval.shape[0]
d3_block_cache = np.zeros((num_data, num_data, num_data))
diag_slice = range(num_data)
d3_block_cache[diag_slice, diag_slice, diag_slice] = np.squeeze(retval)
return d3_block_cache
else:
return retval
return wrapper_func
def blockify_dhess_dtheta(func):
def wrapper_func(self, *args, **kwargs):
# Invoke the wrapped function first
retval = func(self, *args, **kwargs)
# Now do something here with retval and/or action
if self.not_block_really and (len(retval.shape) < 3):
num_data = retval.shape[0]
num_params = retval.shape[-1]
dhess_dtheta = np.zeros((num_data, num_data, num_params))
diag_slice = range(num_data)
for param_ind in range(num_params):
dhess_dtheta[diag_slice, diag_slice, param_ind] = np.squeeze(retval[:,param_ind])
return dhess_dtheta
else:
return retval
return wrapper_func