[priors] proprietary pickling of priors

This commit is contained in:
Max Zwiessele 2014-11-13 15:50:25 +00:00
parent 3c642a5600
commit cd98517599

View file

@ -69,6 +69,14 @@ class Gaussian(Prior):
def rvs(self, n):
return np.random.randn(n) * self.sigma + self.mu
def __getstate__(self):
return self.mu, self.sigma
def __setstate__(self, state):
self.mu = state[0]
self.sigma = state[1]
self.sigma2 = np.square(self.sigma)
self.constant = -0.5 * np.log(2 * np.pi * self.sigma2)
class Uniform(Prior):
domain = _REAL
@ -101,8 +109,14 @@ class Uniform(Prior):
def rvs(self, n):
return np.random.uniform(self.lower, self.upper, size=n)
def __getstate__(self):
return self.lower, self.upper
class LogGaussian(Prior):
def __setstate__(self, state):
self.lower = state[0]
self.upper = state[1]
class LogGaussian(Gaussian):
"""
Implementation of the univariate *log*-Gaussian probability function, coupled with random variables.
@ -202,6 +216,18 @@ class MultivariateGaussian:
priors_plots.multivariate_plot(self)
def __getstate__(self):
return self.mu, self.var
def __setstate__(self, state):
self.mu = state[0]
self.var = state[1]
assert len(self.var.shape) == 2
assert self.var.shape[0] == self.var.shape[1]
assert self.var.shape[0] == self.mu.size
self.input_dim = self.mu.size
self.inv, self.hld = pdinv(self.var)
self.constant = -0.5 * self.input_dim * np.log(2 * np.pi) - self.hld
def gamma_from_EV(E, V):
warnings.warn("use Gamma.from_EV to create Gamma Prior", FutureWarning)
@ -272,7 +298,15 @@ class Gamma(Prior):
b = E / V
return Gamma(a, b)
class InverseGamma(Prior):
def __getstate__(self):
return self.a, self.b
def __setstate__(self, state):
self.a = state[0]
self.b = state[1]
self.constant = -gammaln(self.a) + self.a * np.log(self.b)
class InverseGamma(Gamma):
"""
Implementation of the inverse-Gamma probability function, coupled with random variables.
@ -441,6 +475,21 @@ class DGPLVM_KFDA(Prior):
def __str__(self):
return 'DGPLVM_prior'
def __getstate___(self):
return self.lbl, self.lambdaa, self.sigma2, self.kern, self.x_shape
def __setstate__(self, state):
lbl, lambdaa, sigma2, kern, a, A, x_shape = state
self.datanum = lbl.shape[0]
self.classnum = lbl.shape[1]
self.lambdaa = lambdaa
self.sigma2 = sigma2
self.lbl = lbl
self.kern = kern
lst_ni = self.compute_lst_ni()
self.a = self.compute_a(lst_ni)
self.A = self.compute_A(lst_ni)
self.x_shape = x_shape
class DGPLVM(Prior):
"""