tie framework works roughly

This commit is contained in:
Zhenwen Dai 2014-06-23 19:39:35 +01:00
parent fa6b5343cc
commit 567612b3a9
2 changed files with 33 additions and 20 deletions

View file

@ -423,12 +423,15 @@ class Indexable(Nameable, Observable):
if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED if np.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _connect_fixes(self): def _connect_fixes(self):
fixed_indices = self.constraints[__fixed__] from ties_and_remappings import Tie
if fixed_indices.size > 0: self._ensure_fixes()
self._ensure_fixes() # for c, ind in self.constraints.iteritems():
self._fixes_[fixed_indices] = FIXED # if c == __fixed__ or isinstance(c,Tie):
else: # self._fixes_[ind] = FIXED
self._fixes_ = None [np.put(self._fixes_, ind, FIXED) for c, ind in self.constraints.iteritems()
if c == __fixed__ or isinstance(c,Tie)]
if np.all(self._fixes_): self._fixes_ = None
if self.constraints[__fixed__]==0:
del self.constraints[__fixed__] del self.constraints[__fixed__]
#=========================================================================== #===========================================================================
@ -501,29 +504,25 @@ class Indexable(Nameable, Observable):
old_const = self.constraints.properties()[:] old_const = self.constraints.properties()[:]
self.unconstrain() self.unconstrain()
#set these parameters to be 'fixed' as in, not optimized
self._highest_parent_._set_fixed(self, self._raveled_index())
#see if a tie exists with that name #see if a tie exists with that name
if name in self._highest_parent_.ties: if name in self._highest_parent_.ties:
t = self._highest_parent_.ties[name] t = self._highest_parent_.ties[name]
else: else:
#create a tie object #create a tie object
value = np.atleast_1d(self.param_array)[0]*1 value = np.atleast_1d(self.param_array)[0]*1
import ties_and_remappings from ties_and_remappings import Tie
t = ties_and_remappings.Tie(value=value, name=name) t = Tie(value=value, name=name)
#add the new tie object to the global index #add the new tie object to the global index
self._highest_parent_.ties[name] = t self._highest_parent_.ties[name] = t
self._highest_parent_.add_parameter(t) self._highest_parent_.add_parameter(t)
#constrain the tie as we were constrained #constrain the tie as we were constrained
if len(old_const)==1: if len(old_const)==1:
t.constrain(old_const[0]) t.constrain(old_const[0])
self.constraints.add(t, self._raveled_index()) self.constraints.add(t, self._raveled_index())
t.add_tied_parameter(self) t.add_tied_parameter(self)
self._highest_parent_._connect_fixes()
def constrain(self, transform, warning=True, trigger_parent=True): def constrain(self, transform, warning=True, trigger_parent=True):
""" """
@ -649,10 +648,12 @@ class OptimizationHandlable(Indexable):
def _get_params_transformed(self): def _get_params_transformed(self):
# transformed parameters (apply un-transformation rules) # transformed parameters (apply un-transformation rules)
p = self.param_array.copy() p = self.param_array.copy()
[np.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__] from ties_and_remappings import Tie
[np.put(p, ind, c.finv(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__ and not isinstance(c,Tie)]
if self.has_parent() and self.constraints[__fixed__].size != 0: if self.has_parent() and self.constraints[__fixed__].size != 0:
fixes = np.ones(self.size).astype(bool) fixes = np.ones(self.size).astype(bool)
fixes[self.constraints[__fixed__]] = FIXED [np.put(fixes,ind,FIXED) for c, ind in self.constraints.iteritems()
if c == __fixed__ or isinstance(c,Tie)]
return p[fixes] return p[fixes]
elif self._has_fixes(): elif self._has_fixes():
return p[self._fixes_] return p[self._fixes_]
@ -664,15 +665,21 @@ class OptimizationHandlable(Indexable):
This means, the optimizer sees p, whereas the model sees transformed(p), This means, the optimizer sees p, whereas the model sees transformed(p),
such that, the parameters the model sees are in the right domain. such that, the parameters the model sees are in the right domain.
""" """
from ties_and_remappings import Tie
if not(p is self.param_array): if not(p is self.param_array):
if self.has_parent() and self.constraints[__fixed__].size != 0: if self.has_parent() and self.constraints[__fixed__].size != 0:
fixes = np.ones(self.size).astype(bool) fixes = np.ones(self.size).astype(bool)
fixes[self.constraints[__fixed__]] = FIXED # fixes[self.constraints[__fixed__]] = FIXED
for c, ind in self.constraints.iteritems():
if c == __fixed__ or isinstance(c,Tie):
fixes[ind] = FIXED
self.param_array.flat[fixes] = p self.param_array.flat[fixes] = p
elif self._has_fixes(): self.param_array.flat[self._fixes_] = p elif self._has_fixes(): self.param_array.flat[self._fixes_] = p
else: self.param_array.flat = p else: self.param_array.flat = p
[np.put(self.param_array, ind, c.f(self.param_array.flat[ind])) [np.put(self.param_array, ind, c.f(self.param_array.flat[ind]))
for c, ind in self.constraints.iteritems() if c != __fixed__] for c, ind in self.constraints.iteritems() if c != __fixed__ and not isinstance(c,Tie)]
[np.put(self.param_array, ind, c.val)
for c, ind in self.constraints.iteritems() if isinstance(c,Tie)]
self._trigger_params_changed() self._trigger_params_changed()
def _trigger_params_changed(self, trigger_parent=True): def _trigger_params_changed(self, trigger_parent=True):
@ -699,7 +706,9 @@ class OptimizationHandlable(Indexable):
""" """
if self.has_parent(): if self.has_parent():
return g return g
[np.put(g, i, g[i] * c.gradfactor(self.param_array[i])) for c, i in self.constraints.iteritems() if c != __fixed__] from ties_and_remappings import Tie
[np.put(g, self._raveled_index_for(c.val), g[i].sum()) for c, i in self.constraints.iteritems() if isinstance(c,Tie)]
[np.put(g, i, g[i] * c.gradfactor(self.param_array[i])) for c, i in self.constraints.iteritems() if c != __fixed__ and not isinstance(c,Tie)]
if self._has_fixes(): return g[self._fixes_] if self._has_fixes(): return g[self._fixes_]
return g return g

View file

@ -54,6 +54,7 @@ class Tie(Remapping):
index = self._highest_parent_.constraints[self] index = self._highest_parent_.constraints[self]
if len(index)==0: if len(index)==0:
return # nothing to tie together, this tie exists without any tied parameters return # nothing to tie together, this tie exists without any tied parameters
self.value.gradient[:] = self._highest_parent_.gradient[index].sum()
vals = self._highest_parent_.param_array[index] vals = self._highest_parent_.param_array[index]
uvals = np.unique(vals) uvals = np.unique(vals)
if len(uvals)==1: if len(uvals)==1:
@ -66,6 +67,9 @@ class Tie(Remapping):
else: else:
#more than one of the tied things changed. panic. #more than one of the tied things changed. panic.
raise ValueError, "something is wrong with the tieing" raise ValueError, "something is wrong with the tieing"
def parameters_changed(self):
super(Tie,self).parameters_changed()
self.value.gradient[:] = self._highest_parent_.gradient[self._highest_parent_.constraints[self]].sum()
def mapping(self): def mapping(self):
return self.value return self.value