mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
[updates] updated update structure immensely
This commit is contained in:
parent
4543fc3480
commit
ef05f49b8b
5 changed files with 99 additions and 53 deletions
|
|
@ -49,7 +49,7 @@ class Param(Parameterizable, ObsAr):
|
|||
obj._realshape_ = obj.shape
|
||||
obj._realsize_ = obj.size
|
||||
obj._realndim_ = obj.ndim
|
||||
obj._original_ = True
|
||||
obj._original_ = obj
|
||||
return obj
|
||||
|
||||
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
|
||||
|
|
@ -124,10 +124,10 @@ class Param(Parameterizable, ObsAr):
|
|||
#if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
|
||||
# s += (Ellipsis,)
|
||||
new_arr = super(Param, self).__getitem__(s, *args, **kwargs)
|
||||
try:
|
||||
try:
|
||||
new_arr._current_slice_ = s
|
||||
new_arr._gradient_array_ = self.gradient[s]
|
||||
new_arr._original_ = self.base is new_arr.base
|
||||
new_arr._original_ = self._original_
|
||||
except AttributeError: pass # returning 0d array or float, double etc
|
||||
return new_arr
|
||||
|
||||
|
|
@ -157,29 +157,29 @@ class Param(Parameterizable, ObsAr):
|
|||
return self.constraints[__fixed__].size == self.size
|
||||
|
||||
def _get_original(self, param):
|
||||
return self
|
||||
return self._original_
|
||||
|
||||
#===========================================================================
|
||||
# Pickling and copying
|
||||
#===========================================================================
|
||||
def copy(self):
|
||||
return Parameterizable.copy(self, which=self)
|
||||
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
s = self.__new__(self.__class__, name=self.name, input_array=self.view(numpy.ndarray).copy())
|
||||
memo[id(self)] = s
|
||||
memo[id(self)] = s
|
||||
import copy
|
||||
Pickleable.__setstate__(s, copy.deepcopy(self.__getstate__(), memo))
|
||||
return s
|
||||
def _setup_observers(self):
|
||||
"""
|
||||
Setup the default observers
|
||||
|
||||
|
||||
1: pass through to parent, if present
|
||||
"""
|
||||
if self.has_parent():
|
||||
self.add_observer(self._parent_, self._parent_._pass_through_notify_observers, -np.inf)
|
||||
|
||||
|
||||
#===========================================================================
|
||||
# Printing -> done
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import numpy as np
|
|||
import re
|
||||
import logging
|
||||
|
||||
__updated__ = '2014-05-21'
|
||||
__updated__ = '2014-09-22'
|
||||
|
||||
class HierarchyError(Exception):
|
||||
"""
|
||||
|
|
@ -63,7 +63,7 @@ class Observable(object):
|
|||
@updates.setter
|
||||
def updates(self, ups):
|
||||
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
|
||||
|
||||
|
||||
def update_model(self, updates=None):
|
||||
"""
|
||||
Get or set, whether automatic updates are performed. When updates are
|
||||
|
|
@ -87,21 +87,23 @@ class Observable(object):
|
|||
else:
|
||||
self._updates = updates
|
||||
self.trigger_update()
|
||||
|
||||
|
||||
def toggle_update(self):
|
||||
self.update_model(not self.update())
|
||||
|
||||
def trigger_update(self):
|
||||
def trigger_update(self, trigger_parent=True):
|
||||
"""
|
||||
Update the model from the current state.
|
||||
Make sure that updates are on, otherwise this
|
||||
method will do nothing
|
||||
|
||||
:param bool trigger_parent: Whether to trigger the parent, after self has updated
|
||||
"""
|
||||
if not self.update_model():
|
||||
#print "Warning: updates are off, updating the model will do nothing"
|
||||
return
|
||||
self._trigger_params_changed()
|
||||
|
||||
self._trigger_params_changed(trigger_parent)
|
||||
|
||||
def add_observer(self, observer, callble, priority=0):
|
||||
"""
|
||||
Add an observer `observer` with the callback `callble`
|
||||
|
|
@ -539,18 +541,18 @@ class Indexable(Nameable, Observable):
|
|||
[np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()]
|
||||
return ret
|
||||
return 0.
|
||||
|
||||
|
||||
#===========================================================================
|
||||
# Tie parameters together
|
||||
#===========================================================================
|
||||
|
||||
|
||||
def _has_ties(self):
|
||||
if self._highest_parent_.tie.tied_param is None:
|
||||
return False
|
||||
if self.has_parent():
|
||||
return self._highest_parent_.tie.label_buf[self._highest_parent_._raveled_index_for(self)].sum()>0
|
||||
return True
|
||||
|
||||
|
||||
def tie_together(self):
|
||||
self._highest_parent_.tie.add_tied_parameter(self)
|
||||
self._highest_parent_._set_fixed(self,self._raveled_index())
|
||||
|
|
@ -740,7 +742,7 @@ class OptimizationHandlable(Indexable):
|
|||
self.param_array.flat[f] = p
|
||||
[np.put(self.param_array, ind[f[ind]], c.f(self.param_array.flat[ind[f[ind]]]))
|
||||
for c, ind in self.constraints.iteritems() if c != __fixed__]
|
||||
self._highest_parent_.tie.propagate_val()
|
||||
#self._highest_parent_.tie.propagate_val()
|
||||
|
||||
self._optimizer_copy_transformed = False
|
||||
self._trigger_params_changed()
|
||||
|
|
@ -826,6 +828,7 @@ class OptimizationHandlable(Indexable):
|
|||
"""
|
||||
# first take care of all parameters (from N(0,1))
|
||||
x = rand_gen(size=self._size_transformed(), *args, **kwargs)
|
||||
updates = self.update_model()
|
||||
self.update_model(False) # Switch off the updates
|
||||
self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
||||
# now draw from prior where possible
|
||||
|
|
@ -833,8 +836,8 @@ class OptimizationHandlable(Indexable):
|
|||
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
|
||||
unfixlist = np.ones((self.size,),dtype=np.bool)
|
||||
unfixlist[self.constraints[__fixed__]] = False
|
||||
self.param_array[unfixlist] = x[unfixlist]
|
||||
self.update_model(True)
|
||||
self.param_array.flat[unfixlist] = x.view(np.ndarray).ravel()[unfixlist]
|
||||
self.update_model(updates)
|
||||
|
||||
#===========================================================================
|
||||
# For shared memory arrays. This does nothing in Param, but sets the memory
|
||||
|
|
@ -928,7 +931,7 @@ class Parameterizable(OptimizationHandlable):
|
|||
"""
|
||||
if self.__dict__.get('_param_array_', None) is None:
|
||||
self._param_array_ = np.empty(self.size, dtype=np.float64)
|
||||
|
||||
|
||||
if self.constraints[__fixed__].size !=0:
|
||||
fixes = np.ones(self.size).astype(bool)
|
||||
fixes[self.constraints[__fixed__]] = FIXED
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue