checkgrad now global and callable from any parameter

This commit is contained in:
Max Zwiessele 2014-02-10 16:01:55 +00:00
parent 6a068775f5
commit 826d2f04ff
4 changed files with 27 additions and 14 deletions

View file

@ -8,7 +8,7 @@ import cPickle
import itertools
from re import compile, _pattern_type
from param import ParamConcatenation, Param
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
from index_operations import ParameterIndexOperations,\
index_empty
from array_core import ParamList
@ -24,7 +24,7 @@ FIXED = False
UNFIXED = True
#===============================================================================
class Parameterized(Constrainable, Pickleable, Observable):
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
"""
Parameterized class
@ -230,7 +230,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
elif not (pname in not_unique):
self.__dict__[pname] = p
self._added_names_.add(pname)
#===========================================================================
# Pickling operations
#===========================================================================
@ -385,7 +385,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
"""
return numpy.r_[:self.size]
#===========================================================================
# Handle ties:
# Fixing parameters:
#===========================================================================
def _set_fixed(self, param_or_index):
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
@ -410,9 +410,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
if self._has_fixes():
return self._fixes_[self._raveled_index_for(param)]
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
#===========================================================================
# Fixing parameters:
#===========================================================================
def _fix(self, param, warning=True):
f = self._add_constrain(param, __fixed__, warning)
self._set_fixed(f)