diff --git a/GPy/core/svigp.py b/GPy/core/svigp.py index 7d2bec36..8a2d889d 100644 --- a/GPy/core/svigp.py +++ b/GPy/core/svigp.py @@ -92,12 +92,56 @@ class SVIGP(GPBase): self._vb_steplength_trace = [] def getstate(self): - return GPBase.getstate(self) - + steplength_params = [self.hbar_t, self.tau_t, self.gbar_t, self.gbar_t1, self.gbar_t2, self.hbar_tp, self.tau_tp, self.gbar_tp, self.adapt_param_steplength, self.adapt_vb_steplength, self.vb_steplength, self.param_steplength] + return GPBase.getstate(self) + \ + [self.get_vb_param(), + self.Z, + self.num_inducing, + self.has_uncertain_inputs, + self.X_variance, + self.X_batch, + self.X_variance_batch, + steplength_params, + self.batchcounter, + self.batchsize, + self.epochs, + self.momentum, + self.data_prop, + self._param_trace, + self._param_steplength_trace, + self._vb_steplength_trace, + self._ll_trace, + self._grad_trace, + self.Y, + self._permutation, + self.iterations + ] def setstate(self, state): - return GPBase.setstate(self, state) - + self.iterations = state.pop() + self._permutation = state.pop() + self.Y = state.pop() + self._grad_trace = state.pop() + self._ll_trace = state.pop() + self._vb_steplength_trace = state.pop() + self._param_steplength_trace = state.pop() + self._param_trace = state.pop() + self.data_prop = state.pop() + self.momentum = state.pop() + self.epochs = state.pop() + self.batchsize = state.pop() + self.batchcounter = state.pop() + steplength_params = state.pop() + (self.hbar_t, self.tau_t, self.gbar_t, self.gbar_t1, self.gbar_t2, self.hbar_tp, self.tau_tp, self.gbar_tp, self.adapt_param_steplength, self.adapt_vb_steplength, self.vb_steplength, self.param_steplength) = steplength_params + self.X_variance_batch = state.pop() + self.X_batch = state.pop() + self.X_variance = state.pop() + self.has_uncertain_inputs = state.pop() + self.num_inducing = state.pop() + self.Z = state.pop() + vb_param = state.pop() + GPBase.setstate(self, state) + self.set_vb_param(vb_param) def _compute_kernel_matrices(self): # kernel computations, using BGPLVM notation