From 6a6cbb58c856af04cd2743fafdd84d7f40a2328c Mon Sep 17 00:00:00 2001 From: Nicolo Fusi Date: Tue, 5 Feb 2013 16:40:26 +0000 Subject: [PATCH] sometimes a step with missing data can be a bit unstable --- GPy/inference/SGD.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/GPy/inference/SGD.py b/GPy/inference/SGD.py index a12b808e..29e2072d 100644 --- a/GPy/inference/SGD.py +++ b/GPy/inference/SGD.py @@ -141,7 +141,11 @@ class opt_SGD(Optimizer): b, p = self.shift_constraints(j) momentum_term = self.momentum * step[j] - f, fp = f_fp(self.x_opt[j]) + + try: + f, fp = f_fp(self.x_opt[j]) + except Exception: + return 0, step, self.model.N step[j] = self.learning_rate[j] * fp self.x_opt[j] -= step[j] + momentum_term @@ -171,6 +175,7 @@ class opt_SGD(Optimizer): features = [features[i::b] for i in range(b)] NLL = [] count = 0 + last_printed_count = -1 for j in features: count += 1 @@ -188,9 +193,11 @@ class opt_SGD(Optimizer): self.x_opt -= step + momentum_term if self.messages == 2: - status = "evaluating {feature: 5d}/{tot: 5d} \t f: {f: 2.3f} \t non-missing: {nm: 4d}\r".format(feature = count, tot = len(features), f = f, nm = Nj) - sys.stdout.write(status) - sys.stdout.flush() + if count == last_printed_count + 20 or count == 0: + status = "evaluating {feature: 5d}/{tot: 5d} \t f: {f: 2.3f} \t non-missing: {nm: 4d}\r".format(feature = count, tot = len(features), f = f, nm = Nj) + sys.stdout.write(status) + sys.stdout.flush() + last_printed_count = count NLL.append(f)