mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
general tidying in models
This commit is contained in:
parent
eeb5f59fca
commit
7190e0e6bb
5 changed files with 47 additions and 45 deletions
|
|
@ -49,18 +49,6 @@ class BayesianGPLVM(SparseGP, GPLVM):
|
|||
SparseGP.__init__(self, X, likelihood, kernel, Z=Z, X_variance=X_variance, **kwargs)
|
||||
self.ensure_default_constraints()
|
||||
|
||||
def getstate(self):
|
||||
"""
|
||||
Get the current state of the class,
|
||||
here just all the indices, rest can get recomputed
|
||||
"""
|
||||
return SparseGP.getstate(self) + [self.init]
|
||||
|
||||
def setstate(self, state):
|
||||
self._const_jitter = None
|
||||
self.init = state.pop()
|
||||
SparseGP.setstate(self, state)
|
||||
|
||||
def _get_param_names(self):
|
||||
X_names = sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], [])
|
||||
S_names = sum([['X_variance_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], [])
|
||||
|
|
@ -285,6 +273,19 @@ class BayesianGPLVM(SparseGP, GPLVM):
|
|||
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
|
||||
return fig
|
||||
|
||||
def getstate(self):
|
||||
"""
|
||||
Get the current state of the class,
|
||||
here just all the indices, rest can get recomputed
|
||||
"""
|
||||
return SparseGP.getstate(self) + [self.init]
|
||||
|
||||
def setstate(self, state):
|
||||
self._const_jitter = None
|
||||
self.init = state.pop()
|
||||
SparseGP.setstate(self, state)
|
||||
|
||||
|
||||
def latent_cost_and_grad(mu_S, kern, Z, dL_dpsi0, dL_dpsi1, dL_dpsi2):
|
||||
"""
|
||||
objective function for fitting the latent variables for test points
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pylab as pb
|
|||
import sys, pdb
|
||||
from ..core import GP
|
||||
from ..models import GPLVM
|
||||
from ..mappings import *
|
||||
from ..mappings import Kernel
|
||||
|
||||
|
||||
class BCGPLVM(GPLVM):
|
||||
|
|
|
|||
|
|
@ -39,5 +39,3 @@ class GPRegression(GP):
|
|||
|
||||
def setstate(self, state):
|
||||
return GP.setstate(self, state)
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -44,12 +44,6 @@ class GPLVM(GP):
|
|||
Xr[:PC.shape[0], :PC.shape[1]] = PC
|
||||
return Xr
|
||||
|
||||
def getstate(self):
|
||||
return GP.getstate(self)
|
||||
|
||||
def setstate(self, state):
|
||||
GP.setstate(self, state)
|
||||
|
||||
def _get_param_names(self):
|
||||
return sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], []) + GP._get_param_names(self)
|
||||
|
||||
|
|
@ -68,7 +62,7 @@ class GPLVM(GP):
|
|||
def jacobian(self,X):
|
||||
target = np.zeros((X.shape[0],X.shape[1],self.output_dim))
|
||||
for i in range(self.output_dim):
|
||||
target[:,:,i]=self.kern.dK_dX(np.dot(self.Ki,self.likelihood.Y[:,i])[None, :],X,self.X)
|
||||
target[:,:,i] = self.kern.dK_dX(np.dot(self.Ki,self.likelihood.Y[:,i])[None, :],X,self.X)
|
||||
return target
|
||||
|
||||
def magnification(self,X):
|
||||
|
|
@ -91,3 +85,11 @@ class GPLVM(GP):
|
|||
|
||||
def plot_magnification(self, *args, **kwargs):
|
||||
return util.plot_latent.plot_magnification(self, *args, **kwargs)
|
||||
|
||||
def getstate(self):
|
||||
return GP.getstate(self)
|
||||
|
||||
def setstate(self, state):
|
||||
GP.setstate(self, state)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -81,29 +81,6 @@ class MRD(Model):
|
|||
Model.__init__(self)
|
||||
self.ensure_default_constraints()
|
||||
|
||||
def getstate(self):
|
||||
return Model.getstate(self) + [self.names,
|
||||
self.bgplvms,
|
||||
self.gref,
|
||||
self.nparams,
|
||||
self.input_dim,
|
||||
self.num_inducing,
|
||||
self.num_data,
|
||||
self.NQ,
|
||||
self.MQ]
|
||||
|
||||
def setstate(self, state):
|
||||
self.MQ = state.pop()
|
||||
self.NQ = state.pop()
|
||||
self.num_data = state.pop()
|
||||
self.num_inducing = state.pop()
|
||||
self.input_dim = state.pop()
|
||||
self.nparams = state.pop()
|
||||
self.gref = state.pop()
|
||||
self.bgplvms = state.pop()
|
||||
self.names = state.pop()
|
||||
Model.setstate(self, state)
|
||||
|
||||
@property
|
||||
def X(self):
|
||||
return self.gref.X
|
||||
|
|
@ -371,4 +348,28 @@ class MRD(Model):
|
|||
pylab.draw()
|
||||
fig.tight_layout()
|
||||
|
||||
def getstate(self):
|
||||
return Model.getstate(self) + [self.names,
|
||||
self.bgplvms,
|
||||
self.gref,
|
||||
self.nparams,
|
||||
self.input_dim,
|
||||
self.num_inducing,
|
||||
self.num_data,
|
||||
self.NQ,
|
||||
self.MQ]
|
||||
|
||||
def setstate(self, state):
|
||||
self.MQ = state.pop()
|
||||
self.NQ = state.pop()
|
||||
self.num_data = state.pop()
|
||||
self.num_inducing = state.pop()
|
||||
self.input_dim = state.pop()
|
||||
self.nparams = state.pop()
|
||||
self.gref = state.pop()
|
||||
self.bgplvms = state.pop()
|
||||
self.names = state.pop()
|
||||
Model.setstate(self, state)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue