mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
fix the problem of multiple ties on the same param array object
This commit is contained in:
parent
567612b3a9
commit
3f36a245d1
5 changed files with 18 additions and 24 deletions
|
|
@ -500,8 +500,9 @@ class Indexable(Nameable, Observable):
|
|||
#===========================================================================
|
||||
|
||||
def tie(self, name):
|
||||
from ties_and_remappings import Tie
|
||||
#remove any constraints
|
||||
old_const = self.constraints.properties()[:]
|
||||
old_const = [c for c in self.constraints.properties() if not isinstance(c,Tie)]
|
||||
self.unconstrain()
|
||||
|
||||
#see if a tie exists with that name
|
||||
|
|
@ -510,14 +511,14 @@ class Indexable(Nameable, Observable):
|
|||
else:
|
||||
#create a tie object
|
||||
value = np.atleast_1d(self.param_array)[0]*1
|
||||
from ties_and_remappings import Tie
|
||||
t = Tie(value=value, name=name)
|
||||
|
||||
#add the new tie object to the global index
|
||||
self._highest_parent_.ties[name] = t
|
||||
self._highest_parent_.add_parameter(t)
|
||||
|
||||
#constrain the tie as we were constrained
|
||||
if len(old_const)==1:
|
||||
if len(old_const)>0:
|
||||
t.constrain(old_const[0])
|
||||
|
||||
self.constraints.add(t, self._raveled_index())
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class Tie(Remapping):
|
|||
def callback(self, param=None, which=None):
|
||||
"""
|
||||
This gets called whenever any of the tied parameters changes. we spend
|
||||
considerable effort working out whhat has changed ant to what value.
|
||||
considerable effort working out what has changed and to what value.
|
||||
Then we store that value in self.value, and broadcast it everywhere
|
||||
with parameters_changed.
|
||||
"""
|
||||
|
|
@ -54,11 +54,13 @@ class Tie(Remapping):
|
|||
index = self._highest_parent_.constraints[self]
|
||||
if len(index)==0:
|
||||
return # nothing to tie together, this tie exists without any tied parameters
|
||||
self.value.gradient[:] = self._highest_parent_.gradient[index].sum()
|
||||
self.collate_gradient()
|
||||
vals = self._highest_parent_.param_array[index]
|
||||
uvals = np.unique(vals)
|
||||
if len(uvals)==1:
|
||||
#all of the tied things are at the same value
|
||||
if (self.value==uvals[0]).all():
|
||||
return # DO NOT DO ANY CHANGES IF THE TIED PART IS NOT CHANGED!
|
||||
self.value[...] = uvals[0]
|
||||
elif len(uvals)==2:
|
||||
#only *one* of the tied things has changed. it must be different to self.value
|
||||
|
|
@ -69,7 +71,7 @@ class Tie(Remapping):
|
|||
raise ValueError, "something is wrong with the tieing"
|
||||
def parameters_changed(self):
|
||||
super(Tie,self).parameters_changed()
|
||||
self.value.gradient[:] = self._highest_parent_.gradient[self._highest_parent_.constraints[self]].sum()
|
||||
self.collate_gradient()
|
||||
|
||||
def mapping(self):
|
||||
return self.value
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ class SpikeAndSlabPrior(VariationalPrior):
|
|||
self.pi = Param('pi', pi, Logistic(1e-10,1.-1e-10))
|
||||
self.variance = Param('variance',variance)
|
||||
self.add_parameters(self.pi)
|
||||
self.group_spike_prob = False
|
||||
|
||||
def KL_divergence(self, variational_posterior):
|
||||
mu = variational_posterior.mean
|
||||
|
|
@ -56,11 +55,7 @@ class SpikeAndSlabPrior(VariationalPrior):
|
|||
S = variational_posterior.variance
|
||||
gamma = variational_posterior.binary_prob
|
||||
|
||||
if self.group_spike_prob:
|
||||
gamma_grad = np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
|
||||
gamma.gradient -= gamma_grad.mean(axis=0)
|
||||
else:
|
||||
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
|
||||
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
|
||||
mu.gradient -= gamma*mu
|
||||
S.gradient -= (1. - (1. / (S))) * gamma /2.
|
||||
self.pi.gradient = (gamma/self.pi - (1.-gamma)/(1.-self.pi)).sum(axis=0)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue