mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
merged params here
This commit is contained in:
commit
dab35dcbb0
13 changed files with 220 additions and 412 deletions
|
|
@ -63,14 +63,15 @@ class SpikeAndSlabPrior(VariationalPrior):
|
|||
|
||||
|
||||
class VariationalPosterior(Parameterized):
|
||||
def __init__(self, means=None, variances=None, name=None, **kw):
|
||||
super(VariationalPosterior, self).__init__(name=name, **kw)
|
||||
def __init__(self, means=None, variances=None, name=None, *a, **kw):
|
||||
super(VariationalPosterior, self).__init__(name=name, *a, **kw)
|
||||
self.mean = Param("mean", means)
|
||||
self.variance = Param("variance", variances, Logexp())
|
||||
self.add_parameters(self.mean, self.variance)
|
||||
self.ndim = self.mean.ndim
|
||||
self.shape = self.mean.shape
|
||||
self.num_data, self.input_dim = self.mean.shape
|
||||
self.add_parameters(self.mean, self.variance)
|
||||
self.num_data, self.input_dim = self.mean.shape
|
||||
if self.has_uncertain_inputs():
|
||||
assert self.variance.shape == self.mean.shape, "need one variance per sample and dimenion"
|
||||
|
||||
|
|
@ -78,17 +79,23 @@ class VariationalPosterior(Parameterized):
|
|||
return not self.variance is None
|
||||
|
||||
def __getitem__(self, s):
|
||||
import copy
|
||||
n = self.__new__(self.__class__)
|
||||
dc = copy.copy(self.__dict__)
|
||||
dc['mean'] = dc['mean'][s]
|
||||
dc['variance'] = dc['variance'][s]
|
||||
dc['shape'] = dc['mean'].shape
|
||||
dc['ndim'] = dc['ndim']
|
||||
dc['num_data'], dc['input_dim'] = self.mean.shape[0], self.mean.shape[1] if dc['ndim'] > 1 else 1
|
||||
n.__dict__ = dc
|
||||
return n
|
||||
|
||||
if isinstance(s, (int, slice, tuple, list, np.ndarray)):
|
||||
import copy
|
||||
n = self.__new__(self.__class__, self.name)
|
||||
dc = self.__dict__.copy()
|
||||
dc['mean'] = self.mean[s]
|
||||
dc['variance'] = self.variance[s]
|
||||
dc['_parameters_'] = copy.copy(self._parameters_)
|
||||
n.__dict__.update(dc)
|
||||
n._parameters_[dc['mean']._parent_index_] = dc['mean']
|
||||
n._parameters_[dc['variance']._parent_index_] = dc['variance']
|
||||
n.ndim = n.mean.ndim
|
||||
n.shape = n.mean.shape
|
||||
n.num_data = n.mean.shape[0]
|
||||
n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
|
||||
return n
|
||||
else:
|
||||
return super(VariationalPrior, self).__getitem__(s)
|
||||
|
||||
class NormalPosterior(VariationalPosterior):
|
||||
'''
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue