mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-14 22:42:37 +02:00
improve numerical stability of vardtc_parallel
This commit is contained in:
parent
80adaed616
commit
77a96efeba
3 changed files with 8 additions and 7 deletions
|
|
@ -54,7 +54,7 @@ class Transformation(object):
|
||||||
class Logexp(Transformation):
|
class Logexp(Transformation):
|
||||||
domain = _POSITIVE
|
domain = _POSITIVE
|
||||||
def f(self, x):
|
def f(self, x):
|
||||||
return np.where(x>_lim_val, x, np.log(1. + np.exp(np.clip(x, -_lim_val, _lim_val)))) + epsilon
|
return np.where(x>_lim_val, x, np.log1p(np.exp(np.clip(x, -_lim_val, _lim_val)))) + epsilon
|
||||||
#raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
|
#raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
|
||||||
def finv(self, f):
|
def finv(self, f):
|
||||||
return np.where(f>_lim_val, f, np.log(np.exp(f+1e-20) - 1.))
|
return np.where(f>_lim_val, f, np.log(np.exp(f+1e-20) - 1.))
|
||||||
|
|
|
||||||
|
|
@ -170,13 +170,14 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
diag.add(Kmm, self.const_jitter)
|
diag.add(Kmm, self.const_jitter)
|
||||||
Lm = jitchol(Kmm)
|
Lm = jitchol(Kmm)
|
||||||
|
|
||||||
Lambda = Kmm+psi2_full
|
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||||
|
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
||||||
LL = jitchol(Lambda)
|
LL = jitchol(Lambda)
|
||||||
|
LL = np.dot(Lm,LL)
|
||||||
b,_ = dtrtrs(LL, psi1Y_full.T)
|
b,_ = dtrtrs(LL, psi1Y_full.T)
|
||||||
bbt = np.square(b).sum()
|
bbt = np.square(b).sum()
|
||||||
v,_ = dtrtrs(LL.T,b,lower=False)
|
v,_ = dtrtrs(LL.T,b,lower=False)
|
||||||
vvt = np.einsum('md,od->mo',v,v)
|
vvt = np.einsum('md,od->mo',v,v)
|
||||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
|
||||||
|
|
||||||
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
|
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
|
||||||
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
|
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
|
||||||
|
|
|
||||||
|
|
@ -378,7 +378,7 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
||||||
if self.GPU_direct:
|
if self.GPU_direct:
|
||||||
dL_dpsi1_gpu = dL_dpsi1
|
dL_dpsi1_gpu = dL_dpsi1
|
||||||
dL_dpsi2_gpu = dL_dpsi2
|
dL_dpsi2_gpu = dL_dpsi2
|
||||||
dL_dpsi0_sum = gpuarray.sum(dL_dpsi0).get()
|
dL_dpsi0_sum = dL_dpsi0.get().sum() #gpuarray.sum(dL_dpsi0).get()
|
||||||
else:
|
else:
|
||||||
dL_dpsi1_gpu = self.gpuCache['dL_dpsi1_gpu']
|
dL_dpsi1_gpu = self.gpuCache['dL_dpsi1_gpu']
|
||||||
dL_dpsi2_gpu = self.gpuCache['dL_dpsi2_gpu']
|
dL_dpsi2_gpu = self.gpuCache['dL_dpsi2_gpu']
|
||||||
|
|
@ -394,7 +394,7 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
||||||
self.g_psi1compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi1_gpu.gpudata,psi1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
|
self.g_psi1compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi1_gpu.gpudata,psi1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
|
||||||
self.g_psi2compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi2_gpu.gpudata,psi2n_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
|
self.g_psi2compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi2_gpu.gpudata,psi2n_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
|
||||||
|
|
||||||
dL_dvar = dL_dpsi0_sum + gpuarray.sum(dvar_gpu).get()
|
dL_dvar = dL_dpsi0_sum + dvar_gpu.get().sum()#gpuarray.sum(dvar_gpu).get()
|
||||||
sum_axis(grad_mu_gpu,dmu_gpu,N*Q,self.blocknum)
|
sum_axis(grad_mu_gpu,dmu_gpu,N*Q,self.blocknum)
|
||||||
dL_dmu = grad_mu_gpu.get()
|
dL_dmu = grad_mu_gpu.get()
|
||||||
sum_axis(grad_S_gpu,dS_gpu,N*Q,self.blocknum)
|
sum_axis(grad_S_gpu,dS_gpu,N*Q,self.blocknum)
|
||||||
|
|
@ -404,7 +404,7 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
||||||
sum_axis(grad_l_gpu,dl_gpu,Q,self.blocknum)
|
sum_axis(grad_l_gpu,dl_gpu,Q,self.blocknum)
|
||||||
dL_dlengscale = grad_l_gpu.get()
|
dL_dlengscale = grad_l_gpu.get()
|
||||||
else:
|
else:
|
||||||
dL_dlengscale = gpuarray.sum(dl_gpu).get()
|
dL_dlengscale = dl_gpu.get().sum() #gpuarray.sum(dl_gpu).get()
|
||||||
|
|
||||||
return dL_dvar, dL_dlengscale, dL_dZ, dL_dmu, dL_dS
|
return dL_dvar, dL_dlengscale, dL_dZ, dL_dmu, dL_dS
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue