mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
vardtc updates
This commit is contained in:
parent
8d1cae6459
commit
6b8e418597
1 changed files with 3 additions and 2 deletions
|
|
@ -179,6 +179,7 @@ class VarDTC(object):
|
|||
return post, log_marginal, grad_dict
|
||||
|
||||
class VarDTCMissingData(object):
|
||||
const_jitter = 1e-6
|
||||
def __init__(self, limit=1):
|
||||
from ...util.caching import Cacher
|
||||
self._Y = Cacher(self._subarray_computations, limit)
|
||||
|
|
@ -250,7 +251,7 @@ class VarDTCMissingData(object):
|
|||
|
||||
for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
|
||||
if het_noise: beta = beta_all[ind]
|
||||
else: beta = beta_all[0]
|
||||
else: beta = beta_all
|
||||
|
||||
VVT_factor = (beta*y)
|
||||
VVT_factor_all[v, ind].flat = VVT_factor.flat
|
||||
|
|
@ -311,7 +312,7 @@ class VarDTCMissingData(object):
|
|||
het_noise, uncertain_inputs, LB,
|
||||
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
|
||||
psi0, psi1, beta,
|
||||
data_fit, num_data, output_dim, trYYT)
|
||||
data_fit, num_data, output_dim, trYYT, Y)
|
||||
|
||||
if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue