mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
added absolute difference check to gradcheck
This commit is contained in:
parent
40c9790529
commit
914bdc73d8
1 changed files with 8 additions and 11 deletions
|
|
@ -359,10 +359,7 @@ class model(parameterised):
|
|||
numerical_gradient = (f1 - f2) / (2 * dx)
|
||||
global_ratio = (f1 - f2) / (2 * np.dot(dx, gradient))
|
||||
|
||||
if (np.abs(1. - global_ratio) < tolerance) and not np.isnan(global_ratio):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return (np.abs(1. - global_ratio) < tolerance) or (np.abs(gradient - numerical_gradient).mean() - 1) < tolerance
|
||||
else:
|
||||
# check the gradient of each parameter individually, and do some pretty printing
|
||||
try:
|
||||
|
|
@ -399,7 +396,7 @@ class model(parameterised):
|
|||
ratio = (f1 - f2) / (2 * step * gradient)
|
||||
difference = np.abs((f1 - f2) / 2 / step - gradient)
|
||||
|
||||
if (np.abs(ratio - 1) < tolerance):
|
||||
if (np.abs(1. - ratio) < tolerance) or np.abs(difference) < tolerance:
|
||||
formatted_name = "\033[92m {0} \033[0m".format(names[i])
|
||||
else:
|
||||
formatted_name = "\033[91m {0} \033[0m".format(names[i])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue