mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
observer pattern has a handle to trigger only > min_priority observers
This commit is contained in:
parent
058ab679e7
commit
2771e3f71f
8 changed files with 181 additions and 90 deletions
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
import itertools
|
||||
import numpy
|
||||
from parameter_core import Constrainable, Gradcheckable, Indexable, Parentable, adjust_name_for_printing
|
||||
from array_core import ObservableArray, ParamList
|
||||
from parameter_core import OptimizationHandlable, Gradcheckable, adjust_name_for_printing
|
||||
from array_core import ObservableArray
|
||||
|
||||
###### printing
|
||||
__constraints_name__ = "Constraint"
|
||||
|
|
@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
|
|||
__print_threshold__ = 5
|
||||
######
|
||||
|
||||
class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||
class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
|
||||
"""
|
||||
Parameter object for GPy models.
|
||||
|
||||
|
|
@ -148,8 +148,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
#===========================================================================
|
||||
# get/set parameters
|
||||
#===========================================================================
|
||||
def _set_params(self, param, update=True):
|
||||
def _set_params(self, param, trigger_parent=True):
|
||||
self.flat = param
|
||||
if trigger_parent: min_priority = None
|
||||
else: min_priority = -numpy.inf
|
||||
self._notify_observers(None, min_priority)
|
||||
|
||||
def _get_params(self):
|
||||
return self.flat
|
||||
|
|
@ -175,9 +178,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
|
||||
def __setitem__(self, s, val):
|
||||
super(Param, self).__setitem__(s, val)
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
#self._notify_observers()
|
||||
|
||||
#===========================================================================
|
||||
# Index Operations:
|
||||
|
|
@ -205,6 +205,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
ind = self._indices(slice_index)
|
||||
if ind.ndim < 2: ind = ind[:, None]
|
||||
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
|
||||
|
||||
def _expand_index(self, slice_index=None):
|
||||
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
|
||||
# it basically translates slices to their respective index arrays and turns negative indices around
|
||||
|
|
@ -346,6 +347,7 @@ class ParamConcatenation(object):
|
|||
See :py:class:`GPy.core.parameter.Param` for more details on constraining.
|
||||
"""
|
||||
# self.params = params
|
||||
from lists_and_dicts import ParamList
|
||||
self.params = ParamList([])
|
||||
for p in params:
|
||||
for p in p.flattened_parameters:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue