mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
modified mrd with MZ
This commit is contained in:
parent
c2b7936ebd
commit
5567464968
1 changed files with 34 additions and 2 deletions
|
|
@ -41,8 +41,40 @@ class MRD(model):
|
||||||
:param kernel:
|
:param kernel:
|
||||||
kernel to use
|
kernel to use
|
||||||
"""
|
"""
|
||||||
#TODO allow different kernels for different outputs
|
def __init__(self,likelihood_list,Q,M=10,names=None,kernels=None,initX='PCA',initz='permute',_debug=False, **kwargs):
|
||||||
#def __init__(self, *Ylist, **kwargs):
|
if names is None:
|
||||||
|
self.names = ["{}".format(i + 1) for i in range(len(likelihood_list))]
|
||||||
|
|
||||||
|
#sort out the kernels
|
||||||
|
if kernels is None:
|
||||||
|
kernels = [None]*len(likelihood_list)
|
||||||
|
elif isinstance(kernels,kern.kern):
|
||||||
|
kernels = [kernels.copy() for i in range(len(likelihood_list))]
|
||||||
|
else:
|
||||||
|
assert len(kernels)==len(likelihood_list), "need one kernel per output"
|
||||||
|
assert all([isinstance(k, kern.kern) for k in kernels]), "invalid kernel object detected!"
|
||||||
|
|
||||||
|
self.Q = Q
|
||||||
|
self.M = M
|
||||||
|
self.N = self.gref.N
|
||||||
|
self.NQ = self.N * self.Q
|
||||||
|
self.MQ = self.M * self.Q
|
||||||
|
|
||||||
|
self._init = True
|
||||||
|
X = self._init_X(initx, likelihood_list)
|
||||||
|
Z = self._init_Z(initz, X)
|
||||||
|
self.bgplvms = [Bayesian_GPLVM(l, k, X=X, Z=Z, M=self.M, **kwargs) for l,k in zip(likelihood_list,kernels)]
|
||||||
|
|
||||||
|
del self._init
|
||||||
|
|
||||||
|
self.gref = self.bgplvms[0]
|
||||||
|
nparams = numpy.array([0] + [sparse_GP._get_params(g).size - g.Z.size for g in self.bgplvms])
|
||||||
|
self.nparams = nparams.cumsum()
|
||||||
|
|
||||||
|
model.__init__(self) # @UndefinedVariable
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, *likelihood_list, **kwargs):
|
def __init__(self, *likelihood_list, **kwargs):
|
||||||
if kwargs.has_key("_debug"):
|
if kwargs.has_key("_debug"):
|
||||||
self._debug = kwargs['_debug']
|
self._debug = kwargs['_debug']
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue