mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
[var dtc missing]
This commit is contained in:
parent
c8da45aa8f
commit
8dacea2c13
1 changed files with 21 additions and 7 deletions
|
|
@ -229,18 +229,27 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
#csa = common_subarrays(inan, 1)
|
#csa = common_subarrays(inan, 1)
|
||||||
size = Y.shape[1]
|
size = Y.shape[1]
|
||||||
#logger.info('preparing subarrays {:3.3%}'.format((i+1.)/size))
|
#logger.info('preparing subarrays {:3.3%}'.format((i+1.)/size))
|
||||||
#Ys = [Y[v, :][:, ind] for v, ind in self._subarray_indices]
|
Ys = []
|
||||||
|
next_ten = [0.]
|
||||||
|
count = itertools.count()
|
||||||
|
for v, y in itertools.izip(inan.T, Y.T[:,:,None]):
|
||||||
|
i = count.next()
|
||||||
|
if ((i+1.)/size) >= next_ten[0]:
|
||||||
|
logger.info('preparing subarrays {:>6.1%}'.format((i+1.)/size))
|
||||||
|
next_ten[0] += .1
|
||||||
|
Ys.append(y[v,:])
|
||||||
|
|
||||||
next_ten = [0.]
|
next_ten = [0.]
|
||||||
count = itertools.count()
|
count = itertools.count()
|
||||||
def trace(y):
|
def trace(y):
|
||||||
i = count.next()
|
i = count.next()
|
||||||
if ((i+1.)/size) >= next_ten[0]:
|
if ((i+1.)/size) >= next_ten[0]:
|
||||||
logger.info('preparing traces {:> 3.1%}'.format((i+1.)/size))
|
logger.info('preparing traces {:>6.1%}'.format((i+1.)/size))
|
||||||
next_ten[0] += .1
|
next_ten[0] += .1
|
||||||
y = y[inan[:,i],i:i+1]
|
y = y[inan[:,i],i:i+1]
|
||||||
return np.einsum('ij,ij->', y,y)
|
return np.einsum('ij,ij->', y,y)
|
||||||
traces = [trace(Y) for _ in xrange(size)]
|
traces = [trace(Y) for _ in xrange(size)]
|
||||||
return traces
|
return Ys, traces
|
||||||
else:
|
else:
|
||||||
self._subarray_indices = [[slice(None),slice(None)]]
|
self._subarray_indices = [[slice(None),slice(None)]]
|
||||||
return [Y], [(Y**2).sum()]
|
return [Y], [(Y**2).sum()]
|
||||||
|
|
@ -257,7 +266,7 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
psi1_all = kern.K(X, Z)
|
psi1_all = kern.K(X, Z)
|
||||||
psi2_all = None
|
psi2_all = None
|
||||||
|
|
||||||
traces = self._Y(Y)
|
Ys, traces = self._Y(Y)
|
||||||
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
|
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
|
||||||
het_noise = beta_all.size != 1
|
het_noise = beta_all.size != 1
|
||||||
|
|
||||||
|
|
@ -287,9 +296,14 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
|
|
||||||
#logger.info('computing dimension-wise likelihood and derivatives')
|
#logger.info('computing dimension-wise likelihood and derivatives')
|
||||||
#size = len(Ys)
|
#size = len(Ys)
|
||||||
for i, [y, v, trYYT] in enumerate(itertools.izip(Y.T[:,:,None], self._inan.T, traces)):
|
#size = Y.shape[1]
|
||||||
y = y[v]
|
#next_ten = 0
|
||||||
#logger.info('{:.3%} dimensions:{}'.format((i+1.)/size, ind))
|
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
|
||||||
|
#if ((i+1.)/size) >= next_ten:
|
||||||
|
# logger.info('preparing traces {:> 6.1%}'.format((i+1.)/size))
|
||||||
|
# next_ten += .1
|
||||||
|
|
||||||
|
#y = y[v]
|
||||||
if het_noise: beta = beta_all[i]
|
if het_noise: beta = beta_all[i]
|
||||||
else: beta = beta_all
|
else: beta = beta_all
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue