mirror of https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00

[vardtc missing] performance fixes

This commit is contained in:
parent 66390cbf8a
commit 83e2df838d

2 changed files with 18 additions and 36 deletions
@@ -874,6 +874,9 @@ class Parameterizable(OptimizationHandlable):
         """
         Array representing the parameters of this class.
         There is only one copy of all parameters in memory, two during optimization.
+
+        !WARNING!: setting the parameter array MUST always be done in memory:
+        m.param_array[:] = m_copy.param_array
         """
         if self.__dict__.get('_param_array_', None) is None:
             self._param_array_ = np.empty(self.size, dtype=np.float64)
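Note on the docstring added above: per the warning, the parameter array must be written through in memory rather than rebound. A minimal sketch of that in-place assignment, assuming a stock GPy regression model; the data and the two models here are purely illustrative:

    import numpy as np
    import GPy

    X = np.random.rand(20, 1)
    Y = np.sin(3 * X) + 0.05 * np.random.randn(20, 1)

    m = GPy.models.GPRegression(X, Y)       # model whose parameters get overwritten
    m_copy = GPy.models.GPRegression(X, Y)  # source of the replacement values

    # Slice assignment writes into the single shared parameter buffer,
    # which is exactly what the docstring's warning asks for.
    m.param_array[:] = m_copy.param_array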
@@ -36,11 +36,11 @@ class VarDTC(LatentFunctionInference):
         return param_to_array(np.sum(np.square(Y)))

     def __getstate__(self):
         # has to be overridden, as Cacher objects cannot be pickled.
         return self.limit

     def __setstate__(self, state):
         # has to be overridden, as Cacher objects cannot be pickled.
         self.limit = state
         from ...util.caching import Cacher
         self.get_trYYT = Cacher(self._get_trYYT, self.limit)
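For context, VarDTC above (and VarDTCMissingData below) override __getstate__/__setstate__ only because Cacher objects cannot be pickled: pickling keeps just the cache limit, and the cache is rebuilt on restore. A minimal, generic sketch of that pattern; SimpleCache is an illustrative stand-in, not GPy's Cacher:

    import pickle

    class SimpleCache:
        # Stand-in for a cache wrapper we pretend cannot be pickled.
        def __init__(self, operation, limit):
            self.operation, self.limit, self._store = operation, limit, {}
        def __call__(self, x):
            if x not in self._store:
                if len(self._store) >= self.limit:
                    self._store.clear()          # crude eviction, enough for the sketch
                self._store[x] = self.operation(x)
            return self._store[x]

    class Inference:
        def __init__(self, limit=1):
            self.limit = limit
            self.get_square = SimpleCache(self._square, self.limit)
        def _square(self, x):
            return x * x
        def __getstate__(self):
            # Only the cache size survives pickling.
            return self.limit
        def __setstate__(self, state):
            # Rebuild the cache from scratch, mirroring the GPy pattern above.
            self.limit = state
            self.get_square = SimpleCache(self._square, self.limit)

    restored = pickle.loads(pickle.dumps(Inference(limit=3)))
    assert restored.get_square(4) == 16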
@@ -203,11 +203,11 @@ class VarDTCMissingData(LatentFunctionInference):
         self._Y.limit = limit

     def __getstate__(self):
         # has to be overridden, as Cacher objects cannot be pickled.
         return self._Y.limit, self._inan

     def __setstate__(self, state):
         # has to be overridden, as Cacher objects cannot be pickled.
         from ...util.caching import Cacher
         self.limit = state[0]
         self._inan = state[1]
@@ -273,21 +273,16 @@ class VarDTCMissingData(LatentFunctionInference):
         Lm = jitchol(Kmm)
         if uncertain_inputs: LmInv = dtrtri(Lm)

-        VVT_factor_all = np.empty(Y.shape)
-        full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1]
-        if not full_VVT_factor:
-            psi1V = np.dot(Y.T*beta_all, psi1_all).T
+        #VVT_factor_all = np.empty(Y.shape)
+        #full_VVT_factor = VVT_factor_all.shape[1] == Y.shape[1]
+        #if not full_VVT_factor:
+        #    psi1V = np.dot(Y.T*beta_all, psi1_all).T

         for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
             if het_noise: beta = beta_all[ind]
             else: beta = beta_all

             VVT_factor = (beta*y)
-            try:
-                VVT_factor_all[v, ind].flat = VVT_factor.flat
-            except ValueError:
-                mult = np.ravel_multi_index((v.nonzero()[0][:,None],ind[None,:]), VVT_factor_all.shape)
-                VVT_factor_all.flat[mult] = VVT_factor
             output_dim = y.shape[1]

             psi0 = psi0_all[v]
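The fallback removed above scattered each per-block product into a preallocated array through flat indices built with np.ravel_multi_index. A standalone numpy sketch of that indexing, with purely illustrative shapes and names:

    import numpy as np

    A = np.zeros((5, 4))                             # stand-in for the preallocated array
    v = np.array([True, False, True, True, False])   # boolean row mask for one block
    ind = np.array([1, 3])                            # column indices for that block
    block = np.arange(6, dtype=float).reshape(3, 2)   # per-block values to scatter

    # Flat indices of the (row, column) grid selected by v and ind,
    # then a scatter through .flat, as in the removed except-branch.
    mult = np.ravel_multi_index((v.nonzero()[0][:, None], ind[None, :]), A.shape)
    A.flat[mult] = block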
@@ -347,12 +342,13 @@ class VarDTCMissingData(LatentFunctionInference):
                 psi0, psi1, beta,
                 data_fit, num_data, output_dim, trYYT, Y)

-            if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf
-            else:
-                print 'foobar'
-                tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0)
-                tmp, _ = dpotrs(LB, tmp, lower=1)
-                woodbury_vector[:, ind] = dtrtrs(Lm, tmp, lower=1, trans=1)[0]
+            #if full_VVT_factor:
+            woodbury_vector[:, ind] = Cpsi1Vf
+            #else:
+            # print 'foobar'
+            # tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0)
+            # tmp, _ = dpotrs(LB, tmp, lower=1)
+            # woodbury_vector[:, ind] = dtrtrs(Lm, tmp, lower=1, trans=1)[0]

             #import ipdb;ipdb.set_trace()
             Bi, _ = dpotri(LB, lower=1)
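For reference, the dtrtrs/dpotrs chain commented out above is two triangular solves around a Cholesky solve. A rough restatement with SciPy equivalents rather than GPy's own linalg wrappers; every array below is an illustrative stand-in (Lm for the Cholesky factor of Kmm, LB for the factor of the inner matrix, psi1V for the projected data):

    import numpy as np
    from scipy.linalg import solve_triangular, cho_solve

    M, D = 4, 2
    Kmm = np.random.randn(M, M); Kmm = Kmm @ Kmm.T + M * np.eye(M)
    Lm = np.linalg.cholesky(Kmm)
    B = np.random.randn(M, M); B = B @ B.T + M * np.eye(M)
    LB = np.linalg.cholesky(B)
    psi1V = np.random.randn(M, D)

    # dtrtrs(Lm, psi1V, lower=1, trans=0) ~ solve Lm @ tmp = psi1V
    tmp = solve_triangular(Lm, psi1V, lower=True, trans=0)
    # dpotrs(LB, tmp, lower=1) ~ solve (LB @ LB.T) @ x = tmp
    tmp = cho_solve((LB, True), tmp)
    # dtrtrs(Lm, tmp, lower=1, trans=1) ~ solve Lm.T @ x = tmp
    woodbury_vector = solve_triangular(Lm, tmp, lower=True, trans=1)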
@@ -376,23 +372,6 @@ class VarDTCMissingData(LatentFunctionInference):
                      'dL_dKnm':dL_dpsi1_all,
                      'dL_dthetaL':dL_dthetaL}

-        #get sufficient things for posterior prediction
-        #TODO: do we really want to do this in the loop?
-        #if not full_VVT_factor:
-        # print 'foobar'
-        # psi1V = np.dot(Y.T*beta_all, psi1_all).T
-        # tmp, _ = dtrtrs(Lm, psi1V, lower=1, trans=0)
-        # tmp, _ = dpotrs(LB_all, tmp, lower=1)
-        # woodbury_vector, _ = dtrtrs(Lm, tmp, lower=1, trans=1)
-        #import ipdb;ipdb.set_trace()
-        #Bi, _ = dpotri(LB_all, lower=1)
-        #symmetrify(Bi)
-        #Bi = -dpotri(LB_all, lower=1)[0]
-        #from ...util import diag
-        #diag.add(Bi, 1)
-
-        #woodbury_inv = backsub_both_sides(Lm, Bi)
-
         post = Posterior(woodbury_inv=woodbury_inv_all, woodbury_vector=woodbury_vector, K=Kmm, mean=None, cov=None, K_chol=Lm)

         return post, log_marginal, grad_dict