new functions mrd init_X update

This commit is contained in:
Max Zwiessele 2013-04-15 10:57:27 +01:00
parent cc32825a4a
commit c63beddcf0
4 changed files with 115 additions and 47 deletions

View file

@ -8,7 +8,6 @@ from GPy.models.Bayesian_GPLVM import Bayesian_GPLVM
import numpy
from GPy.models.sparse_GP import sparse_GP
import itertools
from matplotlib import pyplot
import pylab
from GPy.util.linalg import PCA
@ -59,6 +58,8 @@ class MRD(model):
if kwargs.has_key('init'):
init = kwargs['init']
del kwargs['init']
else:
init = "PCA"
try:
self.Q = kwargs["Q"]
except KeyError:
@ -68,11 +69,11 @@ class MRD(model):
except KeyError:
self.M = 10
X = self._init_X(Ylist, init)
self._init = True
X = self._init_X(init, Ylist)
Z = numpy.random.permutation(X.copy())[:self.M]
self.bgplvms = [Bayesian_GPLVM(Y, kernel=k(), X=X, Z=Z, **kwargs) for Y in Ylist]
del self._init
self.gref = self.bgplvms[0]
nparams = numpy.array([0] + [sparse_GP._get_params(g).size - g.Z.size for g in self.bgplvms])
@ -84,6 +85,41 @@ class MRD(model):
model.__init__(self) # @UndefinedVariable
@property
def X(self):
return self.gref.X
@X.setter
def X(self, X):
try:
self.propagate_param(X=X)
except AttributeError:
if not self._init:
raise AttributeError("bgplvm list not initialized")
@property
def Ylist(self):
return [g.likelihood.Y for g in self.bgplvms]
@Ylist.setter
def Ylist(self, Ylist):
for g, Y in itertools.izip(self.bgplvms, Ylist):
g.likelihood.Y = Y
@property
def auto_scale_factor(self):
"""
set auto_scale_factor for all gplvms
:param b: auto_scale_factor
:type b:
"""
return self.gref.auto_scale_factor
@auto_scale_factor.setter
def auto_scale_factor(self, b):
self.propagate_param(auto_scale_factor=b)
def propagate_param(self, **kwargs):
for key, val in kwargs.iteritems():
for g in self.bgplvms:
g.__setattr__(key, val)
def _get_param_names(self):
# X_names = sum([['X_%i_%i' % (n, q) for q in range(self.Q)] for n in range(self.N)], [])
# S_names = sum([['X_variance_%i_%i' % (n, q) for q in range(self.Q)] for n in range(self.N)], [])
@ -129,11 +165,11 @@ class MRD(model):
def _set_params(self, x):
start = 0; end = self.NQ
X = x[start:end].reshape(self.N, self.Q)
X = x[start:end]
start = end; end += start
X_var = x[start:end].reshape(self.N, self.Q)
X_var = x[start:end]
start = end; end += self.MQ
Z = x[start:end].reshape(self.M, self.Q)
Z = x[start:end]
thetas = x[end:]
if self._debug:
@ -144,10 +180,14 @@ class MRD(model):
# set params for all:
for g, s, e in itertools.izip(self.bgplvms, self.nparams, self.nparams[1:]):
self._set_var_params(g, X, X_var, Z)
self._set_kern_params(g, thetas[s:e].copy())
g._compute_kernel_matrices()
g._computations()
g._set_params(numpy.hstack([X, X_var, Z, thetas[s:e]]))
# self._set_var_params(g, X, X_var, Z)
# self._set_kern_params(g, thetas[s:e].copy())
# g._compute_kernel_matrices()
# if self.auto_scale_factor:
# g.scale_factor = numpy.sqrt(g.psi2.sum(0).mean() * g.likelihood.precision)
# # self.scale_factor = numpy.sqrt(self.psi2.sum(0).mean() * self.likelihood.precision)
# g._computations()
def log_likelihood(self):
@ -171,15 +211,18 @@ class MRD(model):
partial=g.partial_for_likelihood)]) \
for g in self.bgplvms])))
def _init_X(self, Ylist, init='PCA_concat'):
if init in "PCA_concat":
X = PCA(numpy.hstack(Ylist), self.Q)[0]
elif init in "PCA_single":
def _init_X(self, init='PCA', Ylist=None):
if Ylist is None:
Ylist = self.Ylist
if init in "PCA_single":
X = numpy.zeros((Ylist[0].shape[0], self.Q))
for qs, Y in itertools.izip(numpy.array_split(numpy.arange(self.Q), len(Ylist)), Ylist):
X[:, qs] = PCA(Y, len(qs))[0]
elif init in "PCA_concat":
X = PCA(numpy.hstack(Ylist), self.Q)[0]
else: # init == 'random':
X = numpy.random.randn(Ylist[0].shape[0], self.Q)
self.X = X
return X
def plot_X(self):
@ -229,7 +272,7 @@ class MRD(model):
def _debug_optimize(self, opt='scg', maxiters=500, itersteps=10):
iters = 0
optstep = lambda: self.optimize(opt, messages=1, max_iters=itersteps)
optstep = lambda: self.optimize(opt, messages=1, max_f_eval=itersteps)
self._debug_plot()
raw_input("enter to start debug")
while iters < maxiters: