mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
enabled some more getting/setting parameters, such as regular expressions and params
This commit is contained in:
parent
01c795ae10
commit
333e24a1c3
5 changed files with 70 additions and 12 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
__updated__ = '2014-03-31'
|
__updated__ = '2014-04-15'
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameter_core import Observable, Pickleable
|
from parameter_core import Observable, Pickleable
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import numpy
|
import numpy
|
||||||
|
np = numpy
|
||||||
from parameter_core import OptimizationHandlable, adjust_name_for_printing
|
from parameter_core import OptimizationHandlable, adjust_name_for_printing
|
||||||
from observable_array import ObsAr
|
from observable_array import ObsAr
|
||||||
|
|
||||||
|
|
@ -118,10 +119,6 @@ class Param(OptimizationHandlable, ObsAr):
|
||||||
except AttributeError: pass # returning 0d array or float, double etc
|
except AttributeError: pass # returning 0d array or float, double etc
|
||||||
return new_arr
|
return new_arr
|
||||||
|
|
||||||
def __setitem__(self, s, val):
|
|
||||||
super(Param, self).__setitem__(s, val)
|
|
||||||
|
|
||||||
|
|
||||||
def _raveled_index(self, slice_index=None):
|
def _raveled_index(self, slice_index=None):
|
||||||
# return an index array on the raveled array, which is formed by the current_slice
|
# return an index array on the raveled array, which is formed by the current_slice
|
||||||
# of this object
|
# of this object
|
||||||
|
|
@ -311,15 +308,15 @@ class ParamConcatenation(object):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def __getitem__(self, s):
|
def __getitem__(self, s):
|
||||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||||
params = [p.param_array[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p.param_array[ind[ps]])]
|
params = [p.param_array.flat[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p.param_array.flat[ind[ps]])]
|
||||||
if len(params)==1: return params[0]
|
if len(params)==1: return params[0]
|
||||||
return ParamConcatenation(params)
|
return ParamConcatenation(params)
|
||||||
def __setitem__(self, s, val, update=True):
|
def __setitem__(self, s, val, update=True):
|
||||||
if isinstance(val, ParamConcatenation):
|
if isinstance(val, ParamConcatenation):
|
||||||
val = val.values()
|
val = val.values()
|
||||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||||
vals = self.values(); vals[s] = val; del val
|
vals = self.values(); vals[s] = val
|
||||||
[numpy.place(p, ind[ps], vals[ps])
|
[numpy.copyto(p, vals[ps], where=ind[ps])
|
||||||
for p, ps in zip(self.params, self._param_slices_)]
|
for p, ps in zip(self.params, self._param_slices_)]
|
||||||
if update:
|
if update:
|
||||||
self.update_all_params()
|
self.update_all_params()
|
||||||
|
|
@ -411,3 +408,42 @@ class ParamConcatenation(object):
|
||||||
return "\n".join(strings)
|
return "\n".join(strings)
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "\n".join(map(repr,self.params))
|
return "\n".join(map(repr,self.params))
|
||||||
|
|
||||||
|
def __ilshift__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__ilshift__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __irshift__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__irshift__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __ixor__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__ixor__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __ipow__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__ipow__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __ifloordiv__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__ifloordiv__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __isub__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__isub__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __ior__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__ior__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __itruediv__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__itruediv__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __idiv__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__idiv__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __iand__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__iand__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __imod__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__imod__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __iadd__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__iadd__(self.values(), *args, **kwargs)
|
||||||
|
|
||||||
|
def __imul__(self, *args, **kwargs):
|
||||||
|
self[:] = np.ndarray.__imul__(self.values(), *args, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ Observable Pattern for patameterization
|
||||||
|
|
||||||
from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import re
|
||||||
|
|
||||||
__updated__ = '2014-03-31'
|
__updated__ = '2014-04-15'
|
||||||
|
|
||||||
class HierarchyError(Exception):
|
class HierarchyError(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
@ -28,7 +29,15 @@ def adjust_name_for_printing(name):
|
||||||
Make sure a name can be printed, alongside used as a variable name.
|
Make sure a name can be printed, alongside used as a variable name.
|
||||||
"""
|
"""
|
||||||
if name is not None:
|
if name is not None:
|
||||||
return name.replace(" ", "_").replace(".", "_").replace("-", "_m_").replace("+", "_p_").replace("!", "_I_").replace("**", "_xx_").replace("*", "_x_").replace("/", "_l_").replace("@", '_at_')
|
name2 = name
|
||||||
|
name = name.replace(" ", "_").replace(".", "_").replace("-", "_m_")
|
||||||
|
name = name.replace("+", "_p_").replace("!", "_I_")
|
||||||
|
name = name.replace("**", "_xx_").replace("*", "_x_")
|
||||||
|
name = name.replace("/", "_l_").replace("@", '_at_')
|
||||||
|
name = name.replace("(", "_of_").replace(")", "")
|
||||||
|
if re.match(r'^[a-zA-Z_][a-zA-Z0-9-_]*$', name) is None:
|
||||||
|
raise NameError, "name {} converted to {} cannot be further converted to valid python variable name!".format(name2, name)
|
||||||
|
return name
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -185,6 +185,8 @@ class Parameterized(Parameterizable):
|
||||||
return ParamConcatenation(paramlist)
|
return ParamConcatenation(paramlist)
|
||||||
|
|
||||||
def __setitem__(self, name, value, paramlist=None):
|
def __setitem__(self, name, value, paramlist=None):
|
||||||
|
if value is None:
|
||||||
|
return # nothing to do here
|
||||||
if isinstance(name, (slice, tuple, np.ndarray)):
|
if isinstance(name, (slice, tuple, np.ndarray)):
|
||||||
try:
|
try:
|
||||||
self.param_array[name] = value
|
self.param_array[name] = value
|
||||||
|
|
@ -197,8 +199,8 @@ class Parameterized(Parameterizable):
|
||||||
param[:] = value
|
param[:] = value
|
||||||
|
|
||||||
def __setattr__(self, name, val):
|
def __setattr__(self, name, val):
|
||||||
# override the default behaviour, if setting a param, so broadcasting can by used
|
# override the default behaviour, if setting a param, so broadcasting can by used
|
||||||
if hasattr(self, '_parameters_'):
|
if hasattr(self, "_parameters_"):
|
||||||
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
|
pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
|
||||||
if name in pnames: self._parameters_[pnames.index(name)][:] = val; return
|
if name in pnames: self._parameters_[pnames.index(name)][:] = val; return
|
||||||
object.__setattr__(self, name, val);
|
object.__setattr__(self, name, val);
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,17 @@ class MiscTests(unittest.TestCase):
|
||||||
m2.kern[:] = m.kern[''].values()
|
m2.kern[:] = m.kern[''].values()
|
||||||
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
np.testing.assert_equal(m.log_likelihood(), m2.log_likelihood())
|
||||||
|
|
||||||
|
def test_model_set_params(self):
|
||||||
|
m = GPy.models.GPRegression(self.X, self.Y)
|
||||||
|
lengthscale = np.random.uniform()
|
||||||
|
m.kern.lengthscale = lengthscale
|
||||||
|
np.testing.assert_equal(m.kern.lengthscale, lengthscale)
|
||||||
|
m.kern.lengthscale *= 1
|
||||||
|
m['.*var'] -= .1
|
||||||
|
np.testing.assert_equal(m.kern.lengthscale, lengthscale)
|
||||||
|
m.optimize()
|
||||||
|
print m
|
||||||
|
|
||||||
def test_model_optimize(self):
|
def test_model_optimize(self):
|
||||||
X = np.random.uniform(-3., 3., (20, 1))
|
X = np.random.uniform(-3., 3., (20, 1))
|
||||||
Y = np.sin(X) + np.random.randn(20, 1) * 0.05
|
Y = np.sin(X) + np.random.randn(20, 1) * 0.05
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue