gradcheck global diff

This commit is contained in:
Max Zwiessele 2014-03-05 12:44:53 +00:00
parent cde8722e1b
commit f5e3e97794

View file

@@ -298,9 +298,11 @@ class Model(Parameterized):
             dx = dx[transformed_index]
             gradient = gradient[transformed_index]
             global_ratio = (f1 - f2) / (2 * np.dot(dx, np.where(gradient == 0, 1e-32, gradient)))
-            return (np.abs(1. - global_ratio) < tolerance)
+            num_grad =(np.abs((f1-f2)/-(2*dx)*np.where(gradient == 0, 1e-32, gradient))).mean()
+            return (np.abs(1. - global_ratio) < tolerance) or (num_grad < tolerance)
         else:
             # check the gradient of each parameter individually, and do some pretty printing
             try: