mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
checkgrad (╯°□°)╯︵ ┻━┻
This commit is contained in:
parent
9af4c34f90
commit
ffd09c7820
4 changed files with 9 additions and 5 deletions
|
|
@ -43,6 +43,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
_parameters_ = []
|
||||
def __new__(cls, name, input_array, default_constraint=None):
|
||||
obj = numpy.atleast_1d(super(Param, cls).__new__(cls, input_array=input_array))
|
||||
cls.__name__ = "Param"
|
||||
obj._current_slice_ = (slice(obj.shape[0]),)
|
||||
obj._realshape_ = obj.shape
|
||||
obj._realsize_ = obj.size
|
||||
|
|
@ -57,7 +58,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
|
||||
def __init__(self, name, input_array, default_constraint=None):
|
||||
super(Param, self).__init__(name=name, default_constraint=default_constraint)
|
||||
|
||||
|
||||
def __array_finalize__(self, obj):
|
||||
# see InfoArray.__array_finalize__ for comments
|
||||
if obj is None: return
|
||||
|
|
@ -75,6 +76,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
self._original_ = getattr(obj, '_original_', None)
|
||||
self._name = getattr(obj, 'name', None)
|
||||
self.gradient = getattr(obj, 'gradient', None)
|
||||
self.constraints = getattr(obj, 'constraints', None)
|
||||
|
||||
def __array_wrap__(self, out_arr, context=None):
|
||||
return out_arr.view(numpy.ndarray)
|
||||
|
|
@ -391,6 +393,9 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
slice_index = self._current_slice_
|
||||
if isinstance(slice_index, (tuple, list)):
|
||||
clean_curr_slice = [s for s in slice_index if numpy.any(s != Ellipsis)]
|
||||
for i in range(self._realndim_-len(clean_curr_slice)):
|
||||
i+=len(clean_curr_slice)
|
||||
clean_curr_slice += range(self._realshape_[i])
|
||||
if (all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
|
||||
and len(set(map(len, clean_curr_slice))) <= 1):
|
||||
return numpy.fromiter(itertools.izip(*clean_curr_slice),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue