mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
new getters and setters for self.params, added m['var'] getter and setter
This commit is contained in:
parent
2096d062cb
commit
000cd5dbb8
2 changed files with 160 additions and 123 deletions
|
|
@ -84,37 +84,6 @@ class model(parameterised):
|
|||
for w in which:
|
||||
self.priors[w] = what
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.get(name)
|
||||
|
||||
def __setitem(self, name, val):
|
||||
return self.set(name, val)
|
||||
|
||||
def get(self, name, return_names=False):
|
||||
"""
|
||||
Get a model parameter by name. The name is applied as a regular expression and all parameters that match that regular expression are returned.
|
||||
"""
|
||||
matches = self.grep_param_names(name)
|
||||
if len(matches):
|
||||
if return_names:
|
||||
return self._get_params()[matches], np.asarray(self._get_param_names())[matches].tolist()
|
||||
else:
|
||||
return self._get_params()[matches]
|
||||
else:
|
||||
raise AttributeError, "no parameter matches %s" % name
|
||||
|
||||
def set(self, name, val):
|
||||
"""
|
||||
Set model parameter(s) by name. The name is provided as a regular expression. All parameters matching that regular expression are set to ghe given value.
|
||||
"""
|
||||
matches = self.grep_param_names(name)
|
||||
if len(matches):
|
||||
x = self._get_params()
|
||||
x[matches] = val
|
||||
self._set_params(x)
|
||||
else:
|
||||
raise AttributeError, "no parameter matches %s" % name
|
||||
|
||||
def get_gradient(self, name, return_names=False):
|
||||
"""
|
||||
Get model gradient(s) by name. The name is applied as a regular expression and all parameters that match that regular expression are returned.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import copy
|
|||
import cPickle
|
||||
import os
|
||||
from ..util.squashers import sigmoid
|
||||
import warnings
|
||||
|
||||
def truncate_pad(string, width, align='m'):
|
||||
"""
|
||||
|
|
@ -55,6 +56,73 @@ class parameterised(object):
|
|||
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
"""
|
||||
Returns a **copy** of parameters in non transformed space
|
||||
|
||||
:see_also: :py:func:`GPy.core.parameterised.params_transformed`
|
||||
"""
|
||||
return self._get_params()
|
||||
@params.setter
|
||||
def params(self, params):
|
||||
self._set_params(params)
|
||||
|
||||
@property
|
||||
def params_transformed(self):
|
||||
"""
|
||||
Returns a **copy** of parameters in transformed space
|
||||
|
||||
:see_also: :py:func:`GPy.core.parameterised.params`
|
||||
"""
|
||||
return self._get_params_transformed()
|
||||
@params_transformed.setter
|
||||
def params_transformed(self, params):
|
||||
self._set_params_transformed(params)
|
||||
|
||||
_get_set_deprecation = """get and set methods wont be available at next minor release
|
||||
in the next releases you will get and set with following syntax:
|
||||
Assume m is a model class:
|
||||
print m['var'] # > prints all parameters matching 'var'
|
||||
m['var'] = 2. # > sets all parameters matching 'var' to 2.
|
||||
m['var'] = <array-like> # > sets parameters matching 'var' to <array-like>
|
||||
"""
|
||||
def get(self, name):
|
||||
warnings.warn(self._get_set_deprecation, FutureWarning, stacklevel=2)
|
||||
return self[name]
|
||||
|
||||
def set(self, name, val):
|
||||
warnings.warn(self._get_set_deprecation, FutureWarning, stacklevel=2)
|
||||
self[name] = val
|
||||
|
||||
def __getitem__(self, name, return_names=False):
|
||||
"""
|
||||
Get a model parameter by name. The name is applied as a regular expression and all parameters that match that regular expression are returned.
|
||||
"""
|
||||
matches = self.grep_param_names(name)
|
||||
if len(matches):
|
||||
if return_names:
|
||||
return self._get_params()[matches], np.asarray(self._get_param_names())[matches].tolist()
|
||||
else:
|
||||
return self._get_params()[matches]
|
||||
else:
|
||||
raise AttributeError, "no parameter matches %s" % name
|
||||
|
||||
def __setitem__(self, name, val):
|
||||
"""
|
||||
Set model parameter(s) by name. The name is provided as a regular expression. All parameters matching that regular expression are set to ghe given value.
|
||||
"""
|
||||
matches = self.grep_param_names(name)
|
||||
if len(matches):
|
||||
val = np.array(val)
|
||||
assert (val.size == 1) or val.size == len(matches), "Shape mismatch: {}:({},)".format(val.size, len(matches))
|
||||
x = self.params
|
||||
x[matches] = val
|
||||
self.params = x
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# self.params[matches] = val
|
||||
else:
|
||||
raise AttributeError, "no parameter matches %s" % name
|
||||
|
||||
def tie_params(self, which):
|
||||
matches = self.grep_param_names(which)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue