Merge branch 'params' of https://github.com/SheffieldML/GPy into params

This commit is contained in:
Neil Lawrence 2014-04-17 07:05:20 -04:00
commit 483cb7ddc0
15 changed files with 376 additions and 495 deletions

View file

@ -1,7 +1,7 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
__updated__ = '2014-03-31'
__updated__ = '2014-04-15'
import numpy as np
from parameter_core import Observable, Pickleable

View file

@ -3,6 +3,7 @@
import itertools
import numpy
np = numpy
from parameter_core import OptimizationHandlable, adjust_name_for_printing
from observable_array import ObsAr
@ -118,10 +119,6 @@ class Param(OptimizationHandlable, ObsAr):
except AttributeError: pass # returning 0d array or float, double etc
return new_arr
def __setitem__(self, s, val):
super(Param, self).__setitem__(s, val)
def _raveled_index(self, slice_index=None):
# return an index array on the raveled array, which is formed by the current_slice
# of this object
@ -311,15 +308,15 @@ class ParamConcatenation(object):
#===========================================================================
def __getitem__(self, s):
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
params = [p.param_array[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p.param_array[ind[ps]])]
params = [p.param_array.flat[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p.param_array.flat[ind[ps]])]
if len(params)==1: return params[0]
return ParamConcatenation(params)
def __setitem__(self, s, val, update=True):
if isinstance(val, ParamConcatenation):
val = val.values()
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
vals = self.values(); vals[s] = val; del val
[numpy.place(p, ind[ps], vals[ps])
vals = self.values(); vals[s] = val
[numpy.copyto(p, vals[ps], where=ind[ps])
for p, ps in zip(self.params, self._param_slices_)]
if update:
self.update_all_params()
@ -342,8 +339,8 @@ class ParamConcatenation(object):
self.update_all_params()
constrain_positive.__doc__ = Param.constrain_positive.__doc__
def constrain_fixed(self, warning=True):
[param.constrain_fixed(warning) for param in self.params]
def constrain_fixed(self, value=None, warning=True, trigger_parent=True):
[param.constrain_fixed(value, warning, trigger_parent) for param in self.params]
constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
fix = constrain_fixed
@ -411,3 +408,42 @@ class ParamConcatenation(object):
return "\n".join(strings)
def __repr__(self):
return "\n".join(map(repr,self.params))
def __ilshift__(self, *args, **kwargs):
self[:] = np.ndarray.__ilshift__(self.values(), *args, **kwargs)
def __irshift__(self, *args, **kwargs):
self[:] = np.ndarray.__irshift__(self.values(), *args, **kwargs)
def __ixor__(self, *args, **kwargs):
self[:] = np.ndarray.__ixor__(self.values(), *args, **kwargs)
def __ipow__(self, *args, **kwargs):
self[:] = np.ndarray.__ipow__(self.values(), *args, **kwargs)
def __ifloordiv__(self, *args, **kwargs):
self[:] = np.ndarray.__ifloordiv__(self.values(), *args, **kwargs)
def __isub__(self, *args, **kwargs):
self[:] = np.ndarray.__isub__(self.values(), *args, **kwargs)
def __ior__(self, *args, **kwargs):
self[:] = np.ndarray.__ior__(self.values(), *args, **kwargs)
def __itruediv__(self, *args, **kwargs):
self[:] = np.ndarray.__itruediv__(self.values(), *args, **kwargs)
def __idiv__(self, *args, **kwargs):
self[:] = np.ndarray.__idiv__(self.values(), *args, **kwargs)
def __iand__(self, *args, **kwargs):
self[:] = np.ndarray.__iand__(self.values(), *args, **kwargs)
def __imod__(self, *args, **kwargs):
self[:] = np.ndarray.__imod__(self.values(), *args, **kwargs)
def __iadd__(self, *args, **kwargs):
self[:] = np.ndarray.__iadd__(self.values(), *args, **kwargs)
def __imul__(self, *args, **kwargs):
self[:] = np.ndarray.__imul__(self.values(), *args, **kwargs)

View file

@ -15,8 +15,9 @@ Observable Pattern for patameterization
from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import numpy as np
import re
__updated__ = '2014-03-31'
__updated__ = '2014-04-16'
class HierarchyError(Exception):
"""
@ -28,7 +29,15 @@ def adjust_name_for_printing(name):
Make sure a name can be printed, alongside used as a variable name.
"""
if name is not None:
return name.replace(" ", "_").replace(".", "_").replace("-", "_m_").replace("+", "_p_").replace("!", "_I_").replace("**", "_xx_").replace("*", "_x_").replace("/", "_l_").replace("@", '_at_')
name2 = name
name = name.replace(" ", "_").replace(".", "_").replace("-", "_m_")
name = name.replace("+", "_p_").replace("!", "_I_")
name = name.replace("**", "_xx_").replace("*", "_x_")
name = name.replace("/", "_l_").replace("@", '_at_')
name = name.replace("(", "_of_").replace(")", "")
if re.match(r'^[a-zA-Z_][a-zA-Z0-9-_]*$', name) is None:
raise NameError, "name {} converted to {} cannot be further converted to valid python variable name!".format(name2, name)
return name
return ''
@ -458,7 +467,7 @@ class Constrainable(Nameable, Indexable, Observable):
Constrain the parameter to the given
:py:class:`GPy.core.transformations.Transformation`.
"""
self.param_array[:] = transform.initialize(self.param_array)
self.param_array[...] = transform.initialize(self.param_array)
reconstrained = self.unconstrain()
self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
self.notify_observers(self, None if trigger_parent else -np.inf)

View file

@ -185,6 +185,8 @@ class Parameterized(Parameterizable):
return ParamConcatenation(paramlist)
def __setitem__(self, name, value, paramlist=None):
if value is None:
return # nothing to do here
if isinstance(name, (slice, tuple, np.ndarray)):
try:
self.param_array[name] = value
@ -197,8 +199,8 @@ class Parameterized(Parameterizable):
param[:] = value
def __setattr__(self, name, val):
# override the default behaviour, if setting a param, so broadcasting can by used
if hasattr(self, '_parameters_'):
# override the default behaviour, if setting a param, so broadcasting can by used
if hasattr(self, "_parameters_"):
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
if name in pnames: self._parameters_[pnames.index(name)][:] = val; return
object.__setattr__(self, name, val);