allowed the gradchecker to return the gradient ratio

Just to help with debugging.
This commit is contained in:
James Hensman 2013-01-11 17:58:19 +00:00
parent 1a135ca9f7
commit c949da2d67

View file

@ -286,7 +286,7 @@ class model(parameterised):
return '\n'.join(s) return '\n'.join(s)
def checkgrad(self, verbose=False, include_priors=False, step=1e-6, tolerance = 1e-3, *args): def checkgrad(self, verbose=False, include_priors=False, step=1e-6, tolerance = 1e-3, return_ratio=False, *args):
""" """
Check the gradient of the model by comparing to a numerical estimate. Check the gradient of the model by comparing to a numerical estimate.
If the overall gradient fails, invividual components are tested. If the overall gradient fails, invividual components are tested.
@ -306,12 +306,12 @@ class model(parameterised):
gradient = self.extract_gradients() gradient = self.extract_gradients()
numerical_gradient = (f1-f2)/(2*dx) numerical_gradient = (f1-f2)/(2*dx)
ratio = (f1-f2)/(2*np.dot(dx,gradient)) global_ratio = (f1-f2)/(2*np.dot(dx,gradient))
if verbose: if verbose:
print "Gradient ratio = ", ratio, '\n' print "Gradient ratio = ", global_ratio, '\n'
sys.stdout.flush() sys.stdout.flush()
if (np.abs(1.-ratio)<tolerance) and not np.isnan(ratio): if (np.abs(1.-global_ratio)<tolerance) and not np.isnan(global_ratio):
if verbose: if verbose:
print 'Gradcheck passed' print 'Gradcheck passed'
else: else:
@ -366,5 +366,12 @@ class model(parameterised):
if verbose: if verbose:
print '' print ''
if return_ratio:
return global_ratio
else:
return False return False
if return_ratio:
return global_ratio
else:
return True return True