mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
fix the SSGPLVM with MPI
This commit is contained in:
parent
08ed72b2f2
commit
cf33808673
7 changed files with 28 additions and 30 deletions
|
|
@ -425,9 +425,6 @@ class Indexable(Nameable, Observable):
|
|||
def _connect_fixes(self):
|
||||
from ties_and_remappings import Tie
|
||||
self._ensure_fixes()
|
||||
# for c, ind in self.constraints.iteritems():
|
||||
# if c == __fixed__ or isinstance(c,Tie):
|
||||
# self._fixes_[ind] = FIXED
|
||||
[np.put(self._fixes_, ind, FIXED) for c, ind in self.constraints.iteritems()
|
||||
if c == __fixed__ or isinstance(c,Tie)]
|
||||
if np.all(self._fixes_): self._fixes_ = None
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class Tie(Remapping):
|
|||
uvals = np.unique(vals)
|
||||
if len(uvals)==1:
|
||||
#all of the tied things are at the same value
|
||||
if (self.value==uvals[0]).all():
|
||||
if np.all(self.value==uvals[0]):
|
||||
return # DO NOT DO ANY CHANGES IF THE TIED PART IS NOT CHANGED!
|
||||
self.value[...] = uvals[0]
|
||||
elif len(uvals)==2:
|
||||
|
|
@ -72,7 +72,7 @@ class Tie(Remapping):
|
|||
def parameters_changed(self):
|
||||
#ensure all out parameters have the correct value, as specified by our mapping
|
||||
index = self._highest_parent_.constraints[self]
|
||||
if (self._highest_parent_.param_array[index]==self.value).all():
|
||||
if np.all(self._highest_parent_.param_array[index]==self.value):
|
||||
return # STOP TRIGGER THE UPDATE LOOP MULTIPLE TIMES!!!
|
||||
self._highest_parent_.param_array[index] = self.mapping()
|
||||
[p.notify_observers(which=self) for p in self.tied_parameters]
|
||||
|
|
|
|||
|
|
@ -150,6 +150,9 @@ class SpikeAndSlabPosterior(VariationalPosterior):
|
|||
n.parameters[dc['mean']._parent_index_] = dc['mean']
|
||||
n.parameters[dc['variance']._parent_index_] = dc['variance']
|
||||
n.parameters[dc['binary_prob']._parent_index_] = dc['binary_prob']
|
||||
n._gradient_array_ = None
|
||||
oversize = self.size - self.mean.size - self.variance.size
|
||||
n.size = n.mean.size + n.variance.size + oversize
|
||||
n.ndim = n.mean.ndim
|
||||
n.shape = n.mean.shape
|
||||
n.num_data = n.mean.shape[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue