mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
vardtc updates
This commit is contained in:
parent
8d1cae6459
commit
6b8e418597
1 changed files with 3 additions and 2 deletions
|
|
@ -179,6 +179,7 @@ class VarDTC(object):
|
||||||
return post, log_marginal, grad_dict
|
return post, log_marginal, grad_dict
|
||||||
|
|
||||||
class VarDTCMissingData(object):
|
class VarDTCMissingData(object):
|
||||||
|
const_jitter = 1e-6
|
||||||
def __init__(self, limit=1):
|
def __init__(self, limit=1):
|
||||||
from ...util.caching import Cacher
|
from ...util.caching import Cacher
|
||||||
self._Y = Cacher(self._subarray_computations, limit)
|
self._Y = Cacher(self._subarray_computations, limit)
|
||||||
|
|
@ -250,7 +251,7 @@ class VarDTCMissingData(object):
|
||||||
|
|
||||||
for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
|
for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
|
||||||
if het_noise: beta = beta_all[ind]
|
if het_noise: beta = beta_all[ind]
|
||||||
else: beta = beta_all[0]
|
else: beta = beta_all
|
||||||
|
|
||||||
VVT_factor = (beta*y)
|
VVT_factor = (beta*y)
|
||||||
VVT_factor_all[v, ind].flat = VVT_factor.flat
|
VVT_factor_all[v, ind].flat = VVT_factor.flat
|
||||||
|
|
@ -311,7 +312,7 @@ class VarDTCMissingData(object):
|
||||||
het_noise, uncertain_inputs, LB,
|
het_noise, uncertain_inputs, LB,
|
||||||
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
|
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
|
||||||
psi0, psi1, beta,
|
psi0, psi1, beta,
|
||||||
data_fit, num_data, output_dim, trYYT)
|
data_fit, num_data, output_dim, trYYT, Y)
|
||||||
|
|
||||||
if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf
|
if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue