manual merging with AS

This commit is contained in:
James Hensman 2015-04-16 12:45:04 +01:00
commit e88b8a88d1
16 changed files with 724 additions and 91 deletions

View file

@ -2,18 +2,37 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from scipy.special import cbrt
from .config import *
_lim_val = np.finfo(np.float64).max
_lim_val_exp = np.log(_lim_val)
_lim_val_square = np.sqrt(_lim_val)
_lim_val_cube = np.power(_lim_val, -3)
#_lim_val_cube = cbrt(_lim_val)
_lim_val_cube = np.nextafter(_lim_val**(1/3.0), -np.inf)
_lim_val_quad = np.nextafter(_lim_val**(1/4.0), -np.inf)
_lim_val_three_times = np.nextafter(_lim_val/3.0, -np.inf)
def safe_exp(f):
clip_f = np.clip(f, -np.inf, _lim_val_exp)
return np.exp(clip_f)
def safe_square(f):
f = np.clip(f, -np.inf, _lim_val_square)
return f**2
def safe_cube(f):
f = np.clip(f, -np.inf, _lim_val_cube)
return f**3
def safe_quad(f):
f = np.clip(f, -np.inf, _lim_val_quad)
return f**4
def safe_three_times(f):
f = np.clip(f, -np.inf, _lim_val_three_times)
return 3*f
def chain_1(df_dg, dg_dx):
"""
Generic chaining function for first derivative
@ -39,8 +58,8 @@ def chain_2(d2f_dg2, dg_dx, df_dg, d2g_dx2):
return d2f_dg2
if len(d2f_dg2) > 1 and len(d2f_dg2.shape)>1 and d2f_dg2.shape[-1] > 1:
raise NotImplementedError('Not implemented for matricies yet')
#dg_dx_2 = np.clip(dg_dx, 1e-12, _lim_val_square)**2
dg_dx_2 = dg_dx**2
dg_dx_2 = np.clip(dg_dx, -np.inf, _lim_val_square)**2
#dg_dx_2 = dg_dx**2
return d2f_dg2*(dg_dx_2) + df_dg*d2g_dx2
def chain_3(d3f_dg3, dg_dx, d2f_dg2, d2g_dx2, df_dg, d3g_dx3):
@ -55,8 +74,8 @@ def chain_3(d3f_dg3, dg_dx, d2f_dg2, d2g_dx2, df_dg, d3g_dx3):
if ( (len(d2f_dg2) > 1 and d2f_dg2.shape[-1] > 1)
or (len(d3f_dg3) > 1 and d3f_dg3.shape[-1] > 1)):
raise NotImplementedError('Not implemented for matricies yet')
#dg_dx_3 = np.clip(dg_dx, 1e-12, _lim_val_cube)**3
dg_dx_3 = dg_dx**3
dg_dx_3 = np.clip(dg_dx, -np.inf, _lim_val_cube)**3
#dg_dx_3 = dg_dx**3
return d3f_dg3*(dg_dx_3) + 3*d2f_dg2*dg_dx*d2g_dx2 + df_dg*d3g_dx3
def opt_wrapper(m, **kwargs):