[priors] pickling priors (not working for Discriminative prior)

This commit is contained in:
Max Zwiessele 2014-11-14 11:09:09 +00:00
parent c22042f5da
commit e7aac70c0a

View file

@ -12,6 +12,11 @@ import weakref
class Prior(object): class Prior(object):
domain = None domain = None
_instance = None
def __new__(cls, *args, **kwargs):
if not cls._instance or cls._instance.__class__ is not cls:
cls._instance = super(Prior, cls).__new__(cls, *args, **kwargs)
return cls._instance
def pdf(self, x): def pdf(self, x):
return np.exp(self.lnpdf(x)) return np.exp(self.lnpdf(x))
@ -41,7 +46,7 @@ class Gaussian(Prior):
domain = _REAL domain = _REAL
_instances = [] _instances = []
def __new__(cls, mu, sigma): # Singleton: def __new__(cls, mu=0, sigma=1): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances:
@ -69,20 +74,20 @@ class Gaussian(Prior):
def rvs(self, n): def rvs(self, n):
return np.random.randn(n) * self.sigma + self.mu return np.random.randn(n) * self.sigma + self.mu
def __getstate__(self): # def __getstate__(self):
return self.mu, self.sigma # return self.mu, self.sigma
#
def __setstate__(self, state): # def __setstate__(self, state):
self.mu = state[0] # self.mu = state[0]
self.sigma = state[1] # self.sigma = state[1]
self.sigma2 = np.square(self.sigma) # self.sigma2 = np.square(self.sigma)
self.constant = -0.5 * np.log(2 * np.pi * self.sigma2) # self.constant = -0.5 * np.log(2 * np.pi * self.sigma2)
class Uniform(Prior): class Uniform(Prior):
domain = _REAL domain = _REAL
_instances = [] _instances = []
def __new__(cls, lower, upper): # Singleton: def __new__(cls, lower=0, upper=1): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances:
@ -109,12 +114,12 @@ class Uniform(Prior):
def rvs(self, n): def rvs(self, n):
return np.random.uniform(self.lower, self.upper, size=n) return np.random.uniform(self.lower, self.upper, size=n)
def __getstate__(self): # def __getstate__(self):
return self.lower, self.upper # return self.lower, self.upper
#
def __setstate__(self, state): # def __setstate__(self, state):
self.lower = state[0] # self.lower = state[0]
self.upper = state[1] # self.upper = state[1]
class LogGaussian(Gaussian): class LogGaussian(Gaussian):
""" """
@ -129,7 +134,7 @@ class LogGaussian(Gaussian):
domain = _POSITIVE domain = _POSITIVE
_instances = [] _instances = []
def __new__(cls, mu, sigma): # Singleton: def __new__(cls, mu=0, sigma=1): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances:
@ -158,7 +163,7 @@ class LogGaussian(Gaussian):
return np.exp(np.random.randn(n) * self.sigma + self.mu) return np.exp(np.random.randn(n) * self.sigma + self.mu)
class MultivariateGaussian: class MultivariateGaussian(Prior):
""" """
Implementation of the multivariate Gaussian probability function, coupled with random variables. Implementation of the multivariate Gaussian probability function, coupled with random variables.
@ -171,7 +176,7 @@ class MultivariateGaussian:
domain = _REAL domain = _REAL
_instances = [] _instances = []
def __new__(cls, mu, var): # Singleton: def __new__(cls, mu=0, var=1): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances:
@ -247,7 +252,7 @@ class Gamma(Prior):
domain = _POSITIVE domain = _POSITIVE
_instances = [] _instances = []
def __new__(cls, a, b): # Singleton: def __new__(cls, a=1, b=.5): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances:
@ -318,7 +323,7 @@ class InverseGamma(Gamma):
""" """
domain = _POSITIVE domain = _POSITIVE
_instances = [] _instances = []
def __new__(cls, a, b): # Singleton: def __new__(cls, a=1, b=.5): # Singleton:
if cls._instances: if cls._instances:
cls._instances[:] = [instance for instance in cls._instances if instance()] cls._instances[:] = [instance for instance in cls._instances if instance()]
for instance in cls._instances: for instance in cls._instances: