observer pattern has a handle to trigger only > min_priority observers

This commit is contained in:
Max Zwiessele 2014-02-28 11:33:31 +00:00
parent 058ab679e7
commit 2771e3f71f
8 changed files with 181 additions and 90 deletions

View file

@ -3,8 +3,8 @@
import itertools
import numpy
from parameter_core import Constrainable, Gradcheckable, Indexable, Parentable, adjust_name_for_printing
from array_core import ObservableArray, ParamList
from parameter_core import OptimizationHandlable, Gradcheckable, adjust_name_for_printing
from array_core import ObservableArray
###### printing
__constraints_name__ = "Constraint"
@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
__print_threshold__ = 5
######
class Param(Constrainable, ObservableArray, Gradcheckable):
class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
"""
Parameter object for GPy models.
@ -148,8 +148,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
#===========================================================================
# get/set parameters
#===========================================================================
def _set_params(self, param, update=True):
def _set_params(self, param, trigger_parent=True):
self.flat = param
if trigger_parent: min_priority = None
else: min_priority = -numpy.inf
self._notify_observers(None, min_priority)
def _get_params(self):
return self.flat
@ -175,9 +178,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
def __setitem__(self, s, val):
super(Param, self).__setitem__(s, val)
if self.has_parent():
self._direct_parent_._notify_parameters_changed()
#self._notify_observers()
#===========================================================================
# Index Operations:
@ -205,6 +205,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
ind = self._indices(slice_index)
if ind.ndim < 2: ind = ind[:, None]
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
def _expand_index(self, slice_index=None):
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
# it basically translates slices to their respective index arrays and turns negative indices around
@ -346,6 +347,7 @@ class ParamConcatenation(object):
See :py:class:`GPy.core.parameter.Param` for more details on constraining.
"""
# self.params = params
from lists_and_dicts import ParamList
self.params = ParamList([])
for p in params:
for p in p.flattened_parameters: