From f4ebae742595eff2c95a68d583a44d0920b39b26 Mon Sep 17 00:00:00 2001 From: Moreno Date: Mon, 5 Mar 2018 12:14:31 +0000 Subject: [PATCH] add serialization functions for EPDTC --- .../expectation_propagation.py | 39 +++++++++++++++++++ GPy/testing/serialization_tests.py | 34 ++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/GPy/inference/latent_function_inference/expectation_propagation.py b/GPy/inference/latent_function_inference/expectation_propagation.py index e92b58cb..61d3feff 100644 --- a/GPy/inference/latent_function_inference/expectation_propagation.py +++ b/GPy/inference/latent_function_inference/expectation_propagation.py @@ -132,6 +132,13 @@ class posteriorParamsDTC(posteriorParamsBase): self.mu += (delta_v-delta_tau*self.mu[i])*si #mu = np.dot(Sigma, v_tilde) + def to_dict(self): + return { "mu": self.mu.tolist(), "Sigma_diag": self.Sigma_diag.tolist()} + + @staticmethod + def from_dict(input_dict): + return posteriorParamsDTC(np.array(input_dict["mu"]), np.array(input_dict["Sigma_diag"])) + @staticmethod def _recompute(LLT0, Kmn, ga_approx): LLT = LLT0 + np.dot(Kmn*ga_approx.tau[None,:],Kmn.T) @@ -533,3 +540,35 @@ class EPDTC(EPBase, VarDTC): #Posterior distribution parameters update if self.parallel_updates == False: post_params._update_rank1(LLT, Kmn, delta_v, delta_tau, i) + + + def to_dict(self): + input_dict = super(EPDTC, self)._to_dict() + input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EPDTC" + if self.ga_approx_old is not None: + input_dict["ga_approx_old"] = self.ga_approx_old.to_dict() + if self._ep_approximation is not None: + input_dict["_ep_approximation"] = {} + input_dict["_ep_approximation"]["post_params"] = self._ep_approximation[0].to_dict() + input_dict["_ep_approximation"]["ga_approx"] = self._ep_approximation[1].to_dict() + input_dict["_ep_approximation"]["cav_params"] = self._ep_approximation[2].to_dict() + input_dict["_ep_approximation"]["log_Z_tilde"] = self._ep_approximation[3].tolist() + + return input_dict + + @staticmethod + def _from_dict(inference_class, input_dict): + ga_approx_old = input_dict.pop('ga_approx_old', None) + if ga_approx_old is not None: + ga_approx_old = gaussianApproximation.from_dict(ga_approx_old) + _ep_approximation_dict = input_dict.pop('_ep_approximation', None) + _ep_approximation = [] + if _ep_approximation is not None: + _ep_approximation.append(posteriorParamsDTC.from_dict(_ep_approximation_dict["post_params"])) + _ep_approximation.append(gaussianApproximation.from_dict(_ep_approximation_dict["ga_approx"])) + _ep_approximation.append(cavityParams.from_dict(_ep_approximation_dict["cav_params"])) + _ep_approximation.append(np.array(_ep_approximation_dict["log_Z_tilde"])) + ee = EPDTC(**input_dict) + ee.ga_approx_old = ga_approx_old + ee._ep_approximation = _ep_approximation + return ee diff --git a/GPy/testing/serialization_tests.py b/GPy/testing/serialization_tests.py index 80dfd219..7eb3fe5c 100644 --- a/GPy/testing/serialization_tests.py +++ b/GPy/testing/serialization_tests.py @@ -116,11 +116,45 @@ class Test(unittest.TestCase): np.testing.assert_array_equal(e1._ep_approximation[2].v[:], e1_r._ep_approximation[2].v[:]) np.testing.assert_array_equal(e1._ep_approximation[3][:], e1_r._ep_approximation[3][:]) + + e1 = GPy.inference.latent_function_inference.expectation_propagation.EPDTC(ep_mode="nested") + e1.ga_approx_old = GPy.inference.latent_function_inference.expectation_propagation.gaussianApproximation(np.random.rand(10),np.random.rand(10)) + e1._ep_approximation = [] + e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.posteriorParamsDTC(np.random.rand(10),np.random.rand(10))) + e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.gaussianApproximation(np.random.rand(10),np.random.rand(10))) + e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.cavityParams(10)) + e1._ep_approximation[-1].v = np.random.rand(10) + e1._ep_approximation[-1].tau = np.random.rand(10) + e1._ep_approximation.append(np.random.rand(10)) + e1_r = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(e1.to_dict()) + + + assert type(e1) == type(e1_r) + assert e1.epsilon==e1_r.epsilon + assert e1.eta==e1_r.eta + assert e1.delta==e1_r.delta + assert e1.always_reset==e1_r.always_reset + assert e1.max_iters==e1_r.max_iters + assert e1.ep_mode==e1_r.ep_mode + assert e1.parallel_updates==e1_r.parallel_updates + + np.testing.assert_array_equal(e1.ga_approx_old.tau[:], e1_r.ga_approx_old.tau[:]) + np.testing.assert_array_equal(e1.ga_approx_old.v[:], e1_r.ga_approx_old.v[:]) + np.testing.assert_array_equal(e1._ep_approximation[0].mu[:], e1_r._ep_approximation[0].mu[:]) + np.testing.assert_array_equal(e1._ep_approximation[0].Sigma_diag[:], e1_r._ep_approximation[0].Sigma_diag[:]) + np.testing.assert_array_equal(e1._ep_approximation[1].tau[:], e1_r._ep_approximation[1].tau[:]) + np.testing.assert_array_equal(e1._ep_approximation[1].v[:], e1_r._ep_approximation[1].v[:]) + np.testing.assert_array_equal(e1._ep_approximation[2].tau[:], e1_r._ep_approximation[2].tau[:]) + np.testing.assert_array_equal(e1._ep_approximation[2].v[:], e1_r._ep_approximation[2].v[:]) + np.testing.assert_array_equal(e1._ep_approximation[3][:], e1_r._ep_approximation[3][:]) + + e2 = GPy.inference.latent_function_inference.exact_gaussian_inference.ExactGaussianInference() e2_r = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(e2.to_dict()) assert type(e2) == type(e2_r) + def test_serialize_deserialize_model(self): np.random.seed(fixed_seed) N = 20