[inferenceX] with missing data

This commit is contained in:
mzwiessele 2014-11-14 12:00:23 +00:00
parent 118ed2733e
commit 1a09a39ab0

View file

@ -34,10 +34,11 @@ class InferenceX(Model):
:type Y: numpy.ndarray
"""
def __init__(self, model, Y, name='inferenceX', init='L2'):
if np.isnan(Y).any():
if np.isnan(Y).any() or getattr(model, 'missing_data', False):
assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!"
self.missing_data = True
self.valid_dim = np.logical_not(np.isnan(Y[0]))
self.ninan = getattr(model, 'ninan', None)
else:
self.missing_data = False
super(InferenceX, self).__init__(name)
@ -109,7 +110,10 @@ class InferenceX(Model):
if self.missing_data:
wv = wv[:,self.valid_dim]
output_dim = self.valid_dim.sum()
self.dL_dpsi2 = beta*(output_dim*self.posterior.woodbury_inv - np.einsum('md,od->mo',wv, wv))/2.
if self.ninan is not None:
self.dL_dpsi2 = beta/2.*(self.posterior.woodbury_inv[:,:,self.valid_dim] - np.einsum('md,od->mo',wv, wv)[:, :, None]).sum(-1)
else:
self.dL_dpsi2 = beta/2.*(output_dim*self.posterior.woodbury_inv - np.einsum('md,od->mo',wv, wv))
self.dL_dpsi1 = beta*np.dot(self.Y[:,self.valid_dim], wv.T)
self.dL_dpsi0 = - beta/2.* np.ones(self.Y.shape[0])
else: