added absolute difference check to gradcheck

This commit is contained in:
Max Zwiessele 2013-05-03 13:35:41 +01:00
parent 40c9790529
commit 914bdc73d8

View file

@ -359,10 +359,7 @@ class model(parameterised):
numerical_gradient = (f1 - f2) / (2 * dx) numerical_gradient = (f1 - f2) / (2 * dx)
global_ratio = (f1 - f2) / (2 * np.dot(dx, gradient)) global_ratio = (f1 - f2) / (2 * np.dot(dx, gradient))
if (np.abs(1. - global_ratio) < tolerance) and not np.isnan(global_ratio): return (np.abs(1. - global_ratio) < tolerance) or (np.abs(gradient - numerical_gradient).mean() - 1) < tolerance
return True
else:
return False
else: else:
# check the gradient of each parameter individually, and do some pretty printing # check the gradient of each parameter individually, and do some pretty printing
try: try:
@ -399,7 +396,7 @@ class model(parameterised):
ratio = (f1 - f2) / (2 * step * gradient) ratio = (f1 - f2) / (2 * step * gradient)
difference = np.abs((f1 - f2) / 2 / step - gradient) difference = np.abs((f1 - f2) / 2 / step - gradient)
if (np.abs(ratio - 1) < tolerance): if (np.abs(1. - ratio) < tolerance) or np.abs(difference) < tolerance:
formatted_name = "\033[92m {0} \033[0m".format(names[i]) formatted_name = "\033[92m {0} \033[0m".format(names[i])
else: else:
formatted_name = "\033[91m {0} \033[0m".format(names[i]) formatted_name = "\033[91m {0} \033[0m".format(names[i])