transformations singleton

This commit is contained in:
Max Zwiessele 2013-09-20 17:21:28 +01:00
parent c2d217e72c
commit 4e102a859b

View file

@ -9,20 +9,22 @@ lim_val = -np.log(sys.float_info.epsilon)
class Transformation(object):
domain = None
_instance = None
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(Transformation, cls).__new__(
cls, *args, **kwargs)
return cls._instance
def f(self, x):
raise NotImplementedError
def finv(self, x):
raise NotImplementedError
def gradfactor(self, f):
""" df_dx evaluated at self.f(x)=f"""
raise NotImplementedError
def initialize(self, f):
""" produce a sensible initial value for f(x)"""
raise NotImplementedError
def __str__(self):
raise NotImplementedError
@ -60,6 +62,14 @@ class LogexpClipped(Logexp):
log_max_bound = np.log(max_bound)
log_min_bound = np.log(min_bound)
domain = POSITIVE
def __new__(cls, lower=1e-6, *args, **kwargs):
if not cls._instance:
cls._instance = super(Transformation, cls).__new__(
cls, lower, *args, **kwargs)
elif cls._instance.lower == lower:
return cls._instance
else:
return super(Transformation, cls).__new__(cls, lower, *args, **kwargs)
def __init__(self, lower=1e-6):
self.lower = lower
def f(self, x):
@ -81,6 +91,7 @@ class LogexpClipped(Logexp):
def __str__(self):
return '(+ve_c)'
class Exponent(Transformation):
# TODO: can't allow this to go to zero, need to set a lower bound. Similar with negative Exponent below. See old MATLAB code.
domain = POSITIVE
@ -125,10 +136,19 @@ class Square(Transformation):
class Logistic(Transformation):
domain = BOUNDED
def __new__(cls, lower=1e-6, upper=1e-6, *args, **kwargs):
if not cls._instance:
cls._instance = super(Transformation, cls).__new__(
cls, *args, **kwargs)
elif cls._instance.lower == lower and cls._instance.upper == upper:
return cls._instance
else:
return super(Transformation, cls).__new__(cls, *args, **kwargs)
def __init__(self, lower, upper):
assert lower < upper
self.lower, self.upper = float(lower), float(upper)
self.difference = self.upper - self.lower
return self
def f(self, x):
return self.lower + self.difference / (1. + np.exp(-x))
def finv(self, f):
@ -142,3 +162,5 @@ class Logistic(Transformation):
def __str__(self):
return '({},{})'.format(self.lower, self.upper)