mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
all parameters in memory
This commit is contained in:
parent
f7223ea377
commit
546d5dfff3
9 changed files with 135 additions and 103 deletions
|
|
@ -65,8 +65,8 @@ class Parameterized(Parameterizable, Pickleable):
|
|||
# **Never** call parameters_changed() yourself
|
||||
__metaclass__ = ParametersChangedMeta
|
||||
#===========================================================================
|
||||
def __init__(self, name=None, *a, **kw):
|
||||
super(Parameterized, self).__init__(name=name, parent=None, parent_index=None, *a, **kw)
|
||||
def __init__(self, name=None, parameters=[], *a, **kw):
|
||||
super(Parameterized, self).__init__(name=name, *a, **kw)
|
||||
self._in_init_ = True
|
||||
self._parameters_ = ArrayList()
|
||||
self.size = sum(p.size for p in self._parameters_)
|
||||
|
|
@ -76,6 +76,7 @@ class Parameterized(Parameterizable, Pickleable):
|
|||
self._param_slices_ = []
|
||||
self._connect_parameters()
|
||||
del self._in_init_
|
||||
self.add_parameters(*parameters)
|
||||
|
||||
def build_pydot(self, G=None):
|
||||
import pydot # @UnresolvedImport
|
||||
|
|
@ -205,25 +206,29 @@ class Parameterized(Parameterizable, Pickleable):
|
|||
return found_params
|
||||
|
||||
def __getitem__(self, name, paramlist=None):
|
||||
if paramlist is None:
|
||||
paramlist = self.grep_param_names(name)
|
||||
if len(paramlist) < 1: raise AttributeError, name
|
||||
if len(paramlist) == 1:
|
||||
if isinstance(paramlist[-1], Parameterized):
|
||||
paramlist = paramlist[-1].flattened_parameters
|
||||
if len(paramlist) != 1:
|
||||
return ParamConcatenation(paramlist)
|
||||
return paramlist[-1]
|
||||
return ParamConcatenation(paramlist)
|
||||
if isinstance(name, (int, slice, tuple, np.ndarray)):
|
||||
return self._param_array_[name]
|
||||
else:
|
||||
if paramlist is None:
|
||||
paramlist = self.grep_param_names(name)
|
||||
if len(paramlist) < 1: raise AttributeError, name
|
||||
if len(paramlist) == 1:
|
||||
if isinstance(paramlist[-1], Parameterized):
|
||||
paramlist = paramlist[-1].flattened_parameters
|
||||
if len(paramlist) != 1:
|
||||
return ParamConcatenation(paramlist)
|
||||
return paramlist[-1]
|
||||
return ParamConcatenation(paramlist)
|
||||
|
||||
def __setitem__(self, name, value, paramlist=None):
|
||||
if isinstance(name, (slice, tuple, np.ndarray)):
|
||||
self._param_array_[name] = value
|
||||
self.notify_observers()
|
||||
else:
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except AttributeError as a: raise a
|
||||
param[:] = value
|
||||
|
||||
|
||||
def __setattr__(self, name, val):
|
||||
# override the default behaviour, if setting a param, so broadcasting can by used
|
||||
if hasattr(self, '_parameters_'):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue