mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[GPU] psi varDTC ready
This commit is contained in:
parent
5e01b94d37
commit
f776db45df
3 changed files with 8 additions and 6 deletions
|
|
@ -302,18 +302,19 @@ def update_gradients(model):
|
||||||
while not isEnd:
|
while not isEnd:
|
||||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, model.X, model.Z, model.likelihood, model.Y)
|
||||||
if isinstance(model.X, VariationalPosterior):
|
if isinstance(model.X, VariationalPosterior):
|
||||||
|
X_slice = model.X[n_range[0]:n_range[1]]
|
||||||
|
|
||||||
#gradients w.r.t. kernel
|
#gradients w.r.t. kernel
|
||||||
model.kern.update_gradients_expectations(variational_posterior=model.X[n_range[0]:n_range[1]], Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
||||||
kern_grad += model.kern.gradient
|
kern_grad += model.kern.gradient
|
||||||
|
|
||||||
#gradients w.r.t. Z
|
#gradients w.r.t. Z
|
||||||
model.Z.gradient[:,model.kern.active_dims] += model.kern.gradients_Z_expectations(
|
model.Z.gradient[:,model.kern.active_dims] += model.kern.gradients_Z_expectations(
|
||||||
grad_dict['dL_dpsi1'], grad_dict['dL_dpsi2'], Z=model.Z, variational_posterior=model.X[n_range[0]:n_range[1]])
|
grad_dict['dL_dpsi1'], grad_dict['dL_dpsi2'], Z=model.Z, variational_posterior=X_slice)
|
||||||
|
|
||||||
#gradients w.r.t. posterior parameters of X
|
#gradients w.r.t. posterior parameters of X
|
||||||
X_grad = model.kern.gradients_qX_expectations(variational_posterior=model.X[n_range[0]:n_range[1]], Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
X_grad = model.kern.gradients_qX_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
||||||
model.set_X_gradients(model.X[n_range[0]:n_range[1]], X_grad)
|
model.set_X_gradients(X_slice, X_grad)
|
||||||
|
|
||||||
if het_noise:
|
if het_noise:
|
||||||
dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL']
|
dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL']
|
||||||
|
|
|
||||||
|
|
@ -328,7 +328,7 @@ class PSICOMP_SSRBF(object):
|
||||||
self.gpuCacheAll = None
|
self.gpuCacheAll = None
|
||||||
self.gpuCache = None
|
self.gpuCache = None
|
||||||
|
|
||||||
@Cache_this(limit=1)
|
@Cache_this(limit=1,ignore_args=(0,))
|
||||||
def psicomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
def psicomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
||||||
"""Compute Psi statitsitcs"""
|
"""Compute Psi statitsitcs"""
|
||||||
if isinstance(lengthscale, np.ndarray) and len(lengthscale)>1:
|
if isinstance(lengthscale, np.ndarray) and len(lengthscale)>1:
|
||||||
|
|
@ -373,7 +373,7 @@ class PSICOMP_SSRBF(object):
|
||||||
|
|
||||||
return psi0_gpu, psi1_gpu, psi2_gpu
|
return psi0_gpu, psi1_gpu, psi2_gpu
|
||||||
|
|
||||||
@Cache_this(limit=1)
|
@Cache_this(limit=1,ignore_args=(0,))
|
||||||
def _psiDercomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
def _psiDercomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
||||||
"""Compute the derivatives w.r.t. Psi statistics"""
|
"""Compute the derivatives w.r.t. Psi statistics"""
|
||||||
N, M, Q = mu.shape[0],Z.shape[0], mu.shape[1]
|
N, M, Q = mu.shape[0],Z.shape[0], mu.shape[1]
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,7 @@ class Cacher(object):
|
||||||
#first make sure the depth limit isn't exceeded
|
#first make sure the depth limit isn't exceeded
|
||||||
if len(self.cached_inputs) == self.limit:
|
if len(self.cached_inputs) == self.limit:
|
||||||
args_ = self.cached_inputs.pop(0)
|
args_ = self.cached_inputs.pop(0)
|
||||||
|
args_ = [a for i,a in enumerate(args_) if i not in self.ignore_args and i not in self.force_kwargs]
|
||||||
[a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
|
[a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
|
||||||
self.inputs_changed.pop(0)
|
self.inputs_changed.pop(0)
|
||||||
self.cached_outputs.pop(0)
|
self.cached_outputs.pop(0)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue