mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
copy and missing data
This commit is contained in:
parent
0082acad63
commit
c4f6b0dbe7
10 changed files with 179 additions and 97 deletions
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import itertools
|
||||
import numpy
|
||||
from parameter_core import Constrainable, Gradcheckable, Indexable, Parameterizable, adjust_name_for_printing
|
||||
from parameter_core import Constrainable, Gradcheckable, Indexable, Parentable, adjust_name_for_printing
|
||||
from array_core import ObservableArray, ParamList
|
||||
|
||||
###### printing
|
||||
|
|
@ -15,7 +15,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
|
|||
__print_threshold__ = 5
|
||||
######
|
||||
|
||||
class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameterizable):
|
||||
class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable):
|
||||
"""
|
||||
Parameter object for GPy models.
|
||||
|
||||
|
|
@ -114,7 +114,14 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
self._parent_index_ = state.pop()
|
||||
self._direct_parent_ = state.pop()
|
||||
self.name = state.pop()
|
||||
|
||||
|
||||
def copy(self, *args):
|
||||
constr = self.constraints.copy()
|
||||
priors = self.priors.copy()
|
||||
p = Param(self.name, self.view(numpy.ndarray).copy(), self._default_constraint_)
|
||||
p.constraints = constr
|
||||
p.priors = priors
|
||||
return p
|
||||
#===========================================================================
|
||||
# get/set parameters
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -68,6 +68,10 @@ class Parentable(object):
|
|||
return self
|
||||
return self._direct_parent_._highest_parent_
|
||||
|
||||
def _notify_parameters_changed(self):
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
|
||||
class Nameable(Parentable):
|
||||
_name = None
|
||||
def __init__(self, name, direct_parent=None, parent_index=None):
|
||||
|
|
@ -80,22 +84,47 @@ class Nameable(Parentable):
|
|||
@name.setter
|
||||
def name(self, name):
|
||||
from_name = self.name
|
||||
assert isinstance(name, str)
|
||||
self._name = name
|
||||
if self.has_parent():
|
||||
self._direct_parent_._name_changed(self, from_name)
|
||||
|
||||
self._direct_parent_._name_changed(self, from_name)
|
||||
|
||||
class Parameterizable(Parentable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||
from GPy.core.parameterization.array_core import ParamList
|
||||
_parameters_ = ParamList()
|
||||
self._added_names_ = set()
|
||||
|
||||
def parameter_names(self, add_name=False):
|
||||
if add_name:
|
||||
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||
|
||||
def _add_parameter_name(self, param):
|
||||
pname = adjust_name_for_printing(param.name)
|
||||
# and makes sure to not delete programmatically added parameters
|
||||
if pname in self.__dict__:
|
||||
if not (param is self.__dict__[pname]):
|
||||
if pname in self._added_names_:
|
||||
del self.__dict__[pname]
|
||||
self._add_parameter_name(param)
|
||||
else:
|
||||
self.__dict__[pname] = param
|
||||
self._added_names_.add(pname)
|
||||
|
||||
def _remove_parameter_name(self, param=None, pname=None):
|
||||
assert param is None or pname is None, "can only delete either param by name, or the name of a param"
|
||||
pname = adjust_name_for_printing(pname) or adjust_name_for_printing(param.name)
|
||||
if pname in self._added_names_:
|
||||
del self.__dict__[pname]
|
||||
self._added_names_.remove(pname)
|
||||
self._connect_parameters()
|
||||
|
||||
def _name_changed(self, param, old_name):
|
||||
self._remove_parameter_name(None, old_name)
|
||||
self._add_parameter_name(param)
|
||||
|
||||
def _collect_gradient(self, target):
|
||||
import itertools
|
||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
|
|
@ -113,6 +142,38 @@ class Parameterizable(Parentable):
|
|||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
self.parameters_changed()
|
||||
|
||||
def copy(self):
|
||||
"""Returns a (deep) copy of the current model"""
|
||||
import copy
|
||||
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
|
||||
from .array_core import ParamList
|
||||
dc = dict()
|
||||
for k, v in self.__dict__.iteritems():
|
||||
if k not in ['_direct_parent_', '_parameters_', '_parent_index_'] + self.parameter_names():
|
||||
if isinstance(v, (Constrainable, ParameterIndexOperations, ParameterIndexOperationsView)):
|
||||
dc[k] = v.copy()
|
||||
else:
|
||||
dc[k] = copy.deepcopy(v)
|
||||
if k == '_parameters_':
|
||||
params = [p.copy() for p in v]
|
||||
#dc = copy.deepcopy(self.__dict__)
|
||||
dc['_direct_parent_'] = None
|
||||
dc['_parent_index_'] = None
|
||||
dc['_parameters_'] = ParamList()
|
||||
s = self.__new__(self.__class__)
|
||||
s.__dict__ = dc
|
||||
#import ipdb;ipdb.set_trace()
|
||||
for p in params:
|
||||
s.add_parameter(p)
|
||||
#dc._notify_parent_change()
|
||||
return s
|
||||
#return copy.deepcopy(self)
|
||||
|
||||
def _notify_parameters_changed(self):
|
||||
self.parameters_changed()
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
|
||||
def parameters_changed(self):
|
||||
"""
|
||||
This method gets called when parameters have changed.
|
||||
|
|
@ -122,11 +183,6 @@ class Parameterizable(Parentable):
|
|||
"""
|
||||
pass
|
||||
|
||||
def _notify_parameters_changed(self):
|
||||
self.parameters_changed()
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
|
||||
|
||||
class Gradcheckable(Parentable):
|
||||
#===========================================================================
|
||||
|
|
@ -157,7 +213,7 @@ class Indexable(object):
|
|||
"""
|
||||
raise NotImplementedError, "shouldnt happen, raveld index transformation required from non parameterization object?"
|
||||
|
||||
class Constrainable(Nameable, Indexable, Parameterizable):
|
||||
class Constrainable(Nameable, Indexable, Parentable):
|
||||
def __init__(self, name, default_constraint=None):
|
||||
super(Constrainable,self).__init__(name)
|
||||
self._default_constraint_ = default_constraint
|
||||
|
|
@ -167,6 +223,16 @@ class Constrainable(Nameable, Indexable, Parameterizable):
|
|||
if self._default_constraint_ is not None:
|
||||
self.constrain(self._default_constraint_)
|
||||
|
||||
def _disconnect_parent(self, constr=None):
|
||||
if constr is None:
|
||||
constr = self.constraints.copy()
|
||||
self.constraints.clear()
|
||||
self.constraints = constr
|
||||
self._direct_parent_ = None
|
||||
self._parent_index_ = None
|
||||
self._connect_fixes()
|
||||
self._notify_parent_change()
|
||||
|
||||
#===========================================================================
|
||||
# Fixing Parameters:
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -3,16 +3,15 @@
|
|||
|
||||
|
||||
import numpy; np = numpy
|
||||
import copy
|
||||
import cPickle
|
||||
import itertools
|
||||
from re import compile, _pattern_type
|
||||
from param import ParamConcatenation, Param
|
||||
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
|
||||
from transformations import __fixed__, FIXED, UNFIXED
|
||||
from param import ParamConcatenation
|
||||
from parameter_core import Constrainable, Pickleable, Observable, Parameterizable, adjust_name_for_printing, Gradcheckable
|
||||
from transformations import __fixed__
|
||||
from array_core import ParamList
|
||||
|
||||
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
||||
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable, Parameterizable):
|
||||
"""
|
||||
Parameterized class
|
||||
|
||||
|
|
@ -63,7 +62,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
self._fixes_ = None
|
||||
self._param_slices_ = []
|
||||
self._connect_parameters()
|
||||
self._added_names_ = set()
|
||||
del self._in_init_
|
||||
|
||||
def add_parameter(self, param, index=None):
|
||||
|
|
@ -117,17 +115,10 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
|
||||
del self._parameters_[param._parent_index_]
|
||||
self.size -= param.size
|
||||
constr = param.constraints.copy()
|
||||
param.constraints.clear()
|
||||
param.constraints = constr
|
||||
param._direct_parent_ = None
|
||||
param._parent_index_ = None
|
||||
param._connect_fixes()
|
||||
param._notify_parent_change()
|
||||
pname = adjust_name_for_printing(param.name)
|
||||
if pname in self._added_names_:
|
||||
del self.__dict__[pname]
|
||||
self._connect_parameters()
|
||||
|
||||
param._disconnect_parent()
|
||||
self._remove_parameter_name(param)
|
||||
|
||||
#self._notify_parent_change()
|
||||
self._connect_fixes()
|
||||
|
||||
|
|
@ -145,19 +136,9 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
for i, p in enumerate(self._parameters_):
|
||||
p._direct_parent_ = self
|
||||
p._parent_index_ = i
|
||||
not_unique = []
|
||||
sizes.append(p.size + sizes[-1])
|
||||
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
||||
pname = adjust_name_for_printing(p.name)
|
||||
# and makes sure to not delete programmatically added parameters
|
||||
if pname in self.__dict__:
|
||||
if isinstance(self.__dict__[pname], (Parameterized, Param)):
|
||||
if not p is self.__dict__[pname]:
|
||||
not_unique.append(pname)
|
||||
del self.__dict__[pname]
|
||||
elif not (pname in not_unique):
|
||||
self.__dict__[pname] = p
|
||||
self._added_names_.add(pname)
|
||||
self._add_parameter_name(p)
|
||||
|
||||
#===========================================================================
|
||||
# Pickling operations
|
||||
|
|
@ -174,19 +155,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
cPickle.dump(self, f, protocol)
|
||||
else:
|
||||
cPickle.dump(self, f, protocol)
|
||||
def copy(self):
|
||||
"""Returns a (deep) copy of the current model """
|
||||
# dc = dict()
|
||||
# for k, v in self.__dict__.iteritems():
|
||||
# if k not in ['_highest_parent_', '_direct_parent_']:
|
||||
# dc[k] = copy.deepcopy(v)
|
||||
|
||||
# dc = copy.deepcopy(self.__dict__)
|
||||
# dc['_highest_parent_'] = None
|
||||
# dc['_direct_parent_'] = None
|
||||
# s = self.__class__.new()
|
||||
# s.__dict__ = dc
|
||||
return copy.deepcopy(self)
|
||||
def __getstate__(self):
|
||||
if self._has_get_set_state():
|
||||
return self._getstate()
|
||||
|
|
@ -265,14 +234,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
if self._has_fixes(): tmp = self._get_params(); tmp[self._fixes_] = p; p = tmp; del tmp
|
||||
[numpy.put(p, ind, c.f(p[ind])) for c, ind in self.constraints.iteritems() if c != __fixed__]
|
||||
return p
|
||||
def _name_changed(self, param, old_name):
|
||||
if hasattr(self, old_name) and old_name in self._added_names_:
|
||||
delattr(self, old_name)
|
||||
self._added_names_.remove(old_name)
|
||||
pname = adjust_name_for_printing(param.name)
|
||||
if pname not in self.__dict__:
|
||||
self._added_names_.add(pname)
|
||||
self.__dict__[pname] = param
|
||||
#===========================================================================
|
||||
# Indexable Handling
|
||||
#===========================================================================
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue