mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
merge current devel in
This commit is contained in:
commit
09589fb50f
114 changed files with 7673 additions and 13114 deletions
|
|
@ -4,7 +4,7 @@
|
|||
import itertools
|
||||
import numpy
|
||||
np = numpy
|
||||
from parameter_core import OptimizationHandlable, adjust_name_for_printing
|
||||
from parameter_core import Parameterizable, adjust_name_for_printing, Pickleable
|
||||
from observable_array import ObsAr
|
||||
|
||||
###### printing
|
||||
|
|
@ -16,7 +16,7 @@ __precision__ = numpy.get_printoptions()['precision'] # numpy printing precision
|
|||
__print_threshold__ = 5
|
||||
######
|
||||
|
||||
class Param(OptimizationHandlable, ObsAr):
|
||||
class Param(Parameterizable, ObsAr):
|
||||
"""
|
||||
Parameter object for GPy models.
|
||||
|
||||
|
|
@ -42,10 +42,9 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
"""
|
||||
__array_priority__ = -1 # Never give back Param
|
||||
_fixes_ = None
|
||||
_parameters_ = []
|
||||
parameters = []
|
||||
def __new__(cls, name, input_array, default_constraint=None):
|
||||
obj = numpy.atleast_1d(super(Param, cls).__new__(cls, input_array=input_array))
|
||||
cls.__name__ = "Param"
|
||||
obj._current_slice_ = (slice(obj.shape[0]),)
|
||||
obj._realshape_ = obj.shape
|
||||
obj._realsize_ = obj.size
|
||||
|
|
@ -58,9 +57,9 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
|
||||
def build_pydot(self,G):
|
||||
import pydot
|
||||
node = pydot.Node(id(self), shape='record', label=self.name)
|
||||
node = pydot.Node(id(self), shape='trapezium', label=self.name)#, fontcolor='white', color='white')
|
||||
G.add_node(node)
|
||||
for o in self.observers.keys():
|
||||
for _, o, _ in self.observers:
|
||||
label = o.name if hasattr(o, 'name') else str(o)
|
||||
observed_node = pydot.Node(id(o), label=label)
|
||||
G.add_node(observed_node)
|
||||
|
|
@ -88,8 +87,18 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
|
||||
@property
|
||||
def param_array(self):
|
||||
"""
|
||||
As we are a leaf, this just returns self
|
||||
"""
|
||||
return self
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
"""
|
||||
Return self as numpy array view
|
||||
"""
|
||||
return self.view(np.ndarray)
|
||||
|
||||
@property
|
||||
def gradient(self):
|
||||
"""
|
||||
|
|
@ -100,11 +109,11 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
"""
|
||||
if getattr(self, '_gradient_array_', None) is None:
|
||||
self._gradient_array_ = numpy.empty(self._realshape_, dtype=numpy.float64)
|
||||
return self._gradient_array_[self._current_slice_]
|
||||
return self._gradient_array_#[self._current_slice_]
|
||||
|
||||
@gradient.setter
|
||||
def gradient(self, val):
|
||||
self._gradient_array_[self._current_slice_] = val
|
||||
self._gradient_array_[:] = val
|
||||
|
||||
#===========================================================================
|
||||
# Array operations -> done
|
||||
|
|
@ -112,10 +121,13 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
def __getitem__(self, s, *args, **kwargs):
|
||||
if not isinstance(s, tuple):
|
||||
s = (s,)
|
||||
if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
|
||||
s += (Ellipsis,)
|
||||
#if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
|
||||
# s += (Ellipsis,)
|
||||
new_arr = super(Param, self).__getitem__(s, *args, **kwargs)
|
||||
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
||||
try:
|
||||
new_arr._current_slice_ = s
|
||||
new_arr._gradient_array_ = self.gradient[s]
|
||||
new_arr._original_ = self.base is new_arr.base
|
||||
except AttributeError: pass # returning 0d array or float, double etc
|
||||
return new_arr
|
||||
|
||||
|
|
@ -130,6 +142,9 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
def _raveled_index_for(self, obj):
|
||||
return self._raveled_index()
|
||||
|
||||
#===========================================================================
|
||||
# Index recreation
|
||||
#===========================================================================
|
||||
def _expand_index(self, slice_index=None):
|
||||
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
|
||||
# it basically translates slices to their respective index arrays and turns negative indices around
|
||||
|
|
@ -138,6 +153,8 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
slice_index = self._current_slice_
|
||||
def f(a):
|
||||
a, b = a
|
||||
if isinstance(a, numpy.ndarray) and a.dtype == bool:
|
||||
raise ValueError, "Boolean indexing not implemented, use Param[np.where(index)] to index by boolean arrays!"
|
||||
if a not in (slice(None), Ellipsis):
|
||||
if isinstance(a, slice):
|
||||
start, stop, step = a.indices(b)
|
||||
|
|
@ -170,14 +187,24 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
#===========================================================================
|
||||
# Pickling and copying
|
||||
#===========================================================================
|
||||
def copy(self):
|
||||
return Parameterizable.copy(self, which=self)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
s = self.__new__(self.__class__, name=self.name, input_array=self.view(numpy.ndarray).copy())
|
||||
memo[id(self)] = s
|
||||
memo[id(self)] = s
|
||||
import copy
|
||||
s.__dict__.update(copy.deepcopy(self.__dict__, memo))
|
||||
Pickleable.__setstate__(s, copy.deepcopy(self.__getstate__(), memo))
|
||||
return s
|
||||
|
||||
|
||||
def _setup_observers(self):
|
||||
"""
|
||||
Setup the default observers
|
||||
|
||||
1: pass through to parent, if present
|
||||
"""
|
||||
if self.has_parent():
|
||||
self.add_observer(self._parent_, self._parent_._pass_through_notify_observers, -np.inf)
|
||||
|
||||
#===========================================================================
|
||||
# Printing -> done
|
||||
#===========================================================================
|
||||
|
|
@ -228,9 +255,16 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
and len(set(map(len, clean_curr_slice))) <= 1):
|
||||
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
||||
dtype=[('', int)] * self._realndim_, count=len(clean_curr_slice[0])).view((int, self._realndim_))
|
||||
expanded_index = list(self._expand_index(slice_index))
|
||||
return numpy.fromiter(itertools.product(*expanded_index),
|
||||
try:
|
||||
expanded_index = list(self._expand_index(slice_index))
|
||||
indices = numpy.fromiter(itertools.product(*expanded_index),
|
||||
dtype=[('', int)] * self._realndim_, count=reduce(lambda a, b: a * b.size, expanded_index, 1)).view((int, self._realndim_))
|
||||
except:
|
||||
print "Warning: extended indexing was used"
|
||||
indices = np.indices(self._realshape_, dtype=int)
|
||||
indices = indices[(slice(None),)+slice_index]
|
||||
indices = np.rollaxis(indices, 0, indices.ndim)
|
||||
return indices
|
||||
def _max_len_names(self, gen, header):
|
||||
gen = map(lambda x: " ".join(map(str, x)), gen)
|
||||
return reduce(lambda a, b:max(a, len(b)), gen, len(header))
|
||||
|
|
@ -272,7 +306,7 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
class ParamConcatenation(object):
|
||||
def __init__(self, params):
|
||||
"""
|
||||
Parameter concatenation for convienience of printing regular expression matched arrays
|
||||
Parameter concatenation for convenience of printing regular expression matched arrays
|
||||
you can index this concatenation as if it was the flattened concatenation
|
||||
of all the parameters it contains, same for setting parameters (Broadcasting enabled).
|
||||
|
||||
|
|
@ -316,8 +350,8 @@ class ParamConcatenation(object):
|
|||
val = val.values()
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
vals = self.values(); vals[s] = val
|
||||
[numpy.copyto(p, vals[ps], where=ind[ps])
|
||||
for p, ps in zip(self.params, self._param_slices_)]
|
||||
for p, ps in zip(self.params, self._param_slices_):
|
||||
p.flat[ind[ps]] = vals[ps]
|
||||
if update:
|
||||
self.update_all_params()
|
||||
def values(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue