mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
add serialization functions for EPDTC
This commit is contained in:
parent
592daced9b
commit
f4ebae7425
2 changed files with 73 additions and 0 deletions
|
|
@ -132,6 +132,13 @@ class posteriorParamsDTC(posteriorParamsBase):
|
|||
self.mu += (delta_v-delta_tau*self.mu[i])*si
|
||||
#mu = np.dot(Sigma, v_tilde)
|
||||
|
||||
def to_dict(self):
|
||||
return { "mu": self.mu.tolist(), "Sigma_diag": self.Sigma_diag.tolist()}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(input_dict):
|
||||
return posteriorParamsDTC(np.array(input_dict["mu"]), np.array(input_dict["Sigma_diag"]))
|
||||
|
||||
@staticmethod
|
||||
def _recompute(LLT0, Kmn, ga_approx):
|
||||
LLT = LLT0 + np.dot(Kmn*ga_approx.tau[None,:],Kmn.T)
|
||||
|
|
@ -533,3 +540,35 @@ class EPDTC(EPBase, VarDTC):
|
|||
#Posterior distribution parameters update
|
||||
if self.parallel_updates == False:
|
||||
post_params._update_rank1(LLT, Kmn, delta_v, delta_tau, i)
|
||||
|
||||
|
||||
def to_dict(self):
|
||||
input_dict = super(EPDTC, self)._to_dict()
|
||||
input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EPDTC"
|
||||
if self.ga_approx_old is not None:
|
||||
input_dict["ga_approx_old"] = self.ga_approx_old.to_dict()
|
||||
if self._ep_approximation is not None:
|
||||
input_dict["_ep_approximation"] = {}
|
||||
input_dict["_ep_approximation"]["post_params"] = self._ep_approximation[0].to_dict()
|
||||
input_dict["_ep_approximation"]["ga_approx"] = self._ep_approximation[1].to_dict()
|
||||
input_dict["_ep_approximation"]["cav_params"] = self._ep_approximation[2].to_dict()
|
||||
input_dict["_ep_approximation"]["log_Z_tilde"] = self._ep_approximation[3].tolist()
|
||||
|
||||
return input_dict
|
||||
|
||||
@staticmethod
|
||||
def _from_dict(inference_class, input_dict):
|
||||
ga_approx_old = input_dict.pop('ga_approx_old', None)
|
||||
if ga_approx_old is not None:
|
||||
ga_approx_old = gaussianApproximation.from_dict(ga_approx_old)
|
||||
_ep_approximation_dict = input_dict.pop('_ep_approximation', None)
|
||||
_ep_approximation = []
|
||||
if _ep_approximation is not None:
|
||||
_ep_approximation.append(posteriorParamsDTC.from_dict(_ep_approximation_dict["post_params"]))
|
||||
_ep_approximation.append(gaussianApproximation.from_dict(_ep_approximation_dict["ga_approx"]))
|
||||
_ep_approximation.append(cavityParams.from_dict(_ep_approximation_dict["cav_params"]))
|
||||
_ep_approximation.append(np.array(_ep_approximation_dict["log_Z_tilde"]))
|
||||
ee = EPDTC(**input_dict)
|
||||
ee.ga_approx_old = ga_approx_old
|
||||
ee._ep_approximation = _ep_approximation
|
||||
return ee
|
||||
|
|
|
|||
|
|
@ -116,11 +116,45 @@ class Test(unittest.TestCase):
|
|||
np.testing.assert_array_equal(e1._ep_approximation[2].v[:], e1_r._ep_approximation[2].v[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[3][:], e1_r._ep_approximation[3][:])
|
||||
|
||||
|
||||
e1 = GPy.inference.latent_function_inference.expectation_propagation.EPDTC(ep_mode="nested")
|
||||
e1.ga_approx_old = GPy.inference.latent_function_inference.expectation_propagation.gaussianApproximation(np.random.rand(10),np.random.rand(10))
|
||||
e1._ep_approximation = []
|
||||
e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.posteriorParamsDTC(np.random.rand(10),np.random.rand(10)))
|
||||
e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.gaussianApproximation(np.random.rand(10),np.random.rand(10)))
|
||||
e1._ep_approximation.append(GPy.inference.latent_function_inference.expectation_propagation.cavityParams(10))
|
||||
e1._ep_approximation[-1].v = np.random.rand(10)
|
||||
e1._ep_approximation[-1].tau = np.random.rand(10)
|
||||
e1._ep_approximation.append(np.random.rand(10))
|
||||
e1_r = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(e1.to_dict())
|
||||
|
||||
|
||||
assert type(e1) == type(e1_r)
|
||||
assert e1.epsilon==e1_r.epsilon
|
||||
assert e1.eta==e1_r.eta
|
||||
assert e1.delta==e1_r.delta
|
||||
assert e1.always_reset==e1_r.always_reset
|
||||
assert e1.max_iters==e1_r.max_iters
|
||||
assert e1.ep_mode==e1_r.ep_mode
|
||||
assert e1.parallel_updates==e1_r.parallel_updates
|
||||
|
||||
np.testing.assert_array_equal(e1.ga_approx_old.tau[:], e1_r.ga_approx_old.tau[:])
|
||||
np.testing.assert_array_equal(e1.ga_approx_old.v[:], e1_r.ga_approx_old.v[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[0].mu[:], e1_r._ep_approximation[0].mu[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[0].Sigma_diag[:], e1_r._ep_approximation[0].Sigma_diag[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[1].tau[:], e1_r._ep_approximation[1].tau[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[1].v[:], e1_r._ep_approximation[1].v[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[2].tau[:], e1_r._ep_approximation[2].tau[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[2].v[:], e1_r._ep_approximation[2].v[:])
|
||||
np.testing.assert_array_equal(e1._ep_approximation[3][:], e1_r._ep_approximation[3][:])
|
||||
|
||||
|
||||
e2 = GPy.inference.latent_function_inference.exact_gaussian_inference.ExactGaussianInference()
|
||||
e2_r = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(e2.to_dict())
|
||||
|
||||
assert type(e2) == type(e2_r)
|
||||
|
||||
|
||||
def test_serialize_deserialize_model(self):
|
||||
np.random.seed(fixed_seed)
|
||||
N = 20
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue