[vardtc missing] performance fixes

This commit is contained in:
mzwiessele 2014-06-27 08:03:45 -07:00
parent 66390cbf8a
commit 83e2df838d
2 changed files with 18 additions and 36 deletions

View file

@@ -874,6 +874,9 @@ class Parameterizable(OptimizationHandlable):
""" """
Array representing the parameters of this class. Array representing the parameters of this class.
There is only one copy of all parameters in memory, two during optimization. There is only one copy of all parameters in memory, two during optimization.
!WARNING!: setting the parameter array MUST always be done in memory:
m.param_array[:] = m_copy.param_array
""" """
if self.__dict__.get('_param_array_', None) is None: if self.__dict__.get('_param_array_', None) is None:
self._param_array_ = np.empty(self.size, dtype=np.float64) self._param_array_ = np.empty(self.size, dtype=np.float64)

View file

@@ -36,11 +36,11 @@ class VarDTC(LatentFunctionInference):
return param_to_array(np.sum(np.square(Y))) return param_to_array(np.sum(np.square(Y)))
def __getstate__(self): def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled. # has to be overridden, as Cacher objects cannot be pickled.
return self.limit return self.limit
def __setstate__(self, state): def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled. # has to be overridden, as Cacher objects cannot be pickled.
self.limit = state self.limit = state
from ...util.caching import Cacher from ...util.caching import Cacher
self.get_trYYT = Cacher(self._get_trYYT, self.limit) self.get_trYYT = Cacher(self._get_trYYT, self.limit)
@@ -203,11 +203,11 @@ class VarDTCMissingData(LatentFunctionInference):
self._Y.limit = limit self._Y.limit = limit
def __getstate__(self): def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled. # has to be overridden, as Cacher objects cannot be pickled.
return self._Y.limit, self._inan return self._Y.limit, self._inan
def __setstate__(self, state): def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled. # has to be overridden, as Cacher objects cannot be pickled.
from ...util.caching import Cacher from ...util.caching import Cacher
self.limit = state[0] self.limit = state[0]
self._inan = state[1] self._inan = state[1]
@@ -273,21 +273,16 @@ class VarDTCMissingData(LatentFunctionInference):
Lm = jitchol(Kmm) Lm = jitchol(Kmm)
if uncertain_inputs: LmInv = dtrtri(Lm) if uncertain_inputs: LmInv = dtrtri(Lm)
VVT_factor_all = np.empty(Y.shape) #VVT_factor_all = np.empty(Y.shape)
full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1] #full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1]
if not full_VVT_factor: #if not full_VVT_factor:
psi1V = np.dot(Y.T*beta_all, psi1_all).T # psi1V = np.dot(Y.T*beta_all, psi1_all).T
for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices): for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
if het_noise: beta = beta_all[ind] if het_noise: beta = beta_all[ind]
else: beta = beta_all else: beta = beta_all
VVT_factor = (beta*y) VVT_factor = (beta*y)
try:
VVT_factor_all[v, ind].flat = VVT_factor.flat
except ValueError:
mult = np.ravel_multi_index((v.nonzero()[0][:,None],ind[None,:]), VVT_factor_all.shape)
VVT_factor_all.flat[mult] = VVT_factor
output_dim = y.shape[1] output_dim = y.shape[1]
psi0 = psi0_all[v] psi0 = psi0_all[v]
@@ -347,12 +342,13 @@ class VarDTCMissingData(LatentFunctionInference):
psi0, psi1, beta, psi0, psi1, beta,
data_fit, num_data, output_dim, trYYT, Y) data_fit, num_data, output_dim, trYYT, Y)
if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf #if full_VVT_factor:
else: woodbury_vector[:, ind] = Cpsi1Vf
print 'foobar' #else:
tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0) # print 'foobar'
tmp, _ = dpotrs(LB, tmp, lower=1) # tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0)
woodbury_vector[:, ind] = dtrtrs(Lm, tmp, lower=1, trans=1)[0] # tmp, _ = dpotrs(LB, tmp, lower=1)
# woodbury_vector[:, ind] = dtrtrs(Lm, tmp, lower=1, trans=1)[0]
#import ipdb;ipdb.set_trace() #import ipdb;ipdb.set_trace()
Bi, _ = dpotri(LB, lower=1) Bi, _ = dpotri(LB, lower=1)
@@ -376,23 +372,6 @@ class VarDTCMissingData(LatentFunctionInference):
'dL_dKnm':dL_dpsi1_all, 'dL_dKnm':dL_dpsi1_all,
'dL_dthetaL':dL_dthetaL} 'dL_dthetaL':dL_dthetaL}
#get sufficient things for posterior prediction
#TODO: do we really want to do this in the loop?
#if not full_VVT_factor:
# print 'foobar'
# psi1V = np.dot(Y.T*beta_all, psi1_all).T
# tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0)
# tmp, _ = dpotrs(LB_all, tmp, lower=1)
# woodbury_vector, _ = dtrtrs(Lm, tmp, lower=1, trans=1)
#import ipdb;ipdb.set_trace()
#Bi, _ = dpotri(LB_all, lower=1)
#symmetrify(Bi)
#Bi = -dpotri(LB_all, lower=1)[0]
#from ...util import diag
#diag.add(Bi, 1)
#woodbury_inv = backsub_both_sides(Lm, Bi)
post = Posterior(woodbury_inv=woodbury_inv_all, woodbury_vector=woodbury_vector, K=Kmm, mean=None, cov=None, K_chol=Lm) post = Posterior(woodbury_inv=woodbury_inv_all, woodbury_vector=woodbury_vector, K=Kmm, mean=None, cov=None, K_chol=Lm)
return post, log_marginal, grad_dict return post, log_marginal, grad_dict