diff --git a/GPy/core/transformations.py b/GPy/core/transformations.py index 0524dfe3..d475f29e 100644 --- a/GPy/core/transformations.py +++ b/GPy/core/transformations.py @@ -5,15 +5,15 @@ import numpy as np from GPy.core.domains import POSITIVE, NEGATIVE, BOUNDED import sys +import weakref lim_val = -np.log(sys.float_info.epsilon) class Transformation(object): domain = None _instance = None def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = super(Transformation, cls).__new__( - cls, *args, **kwargs) + if not cls._instance or cls._instance.__class__ is not cls: + cls._instance = super(Transformation, cls).__new__(cls, *args, **kwargs) return cls._instance def f(self, x): raise NotImplementedError @@ -62,14 +62,16 @@ class LogexpClipped(Logexp): log_max_bound = np.log(max_bound) log_min_bound = np.log(min_bound) domain = POSITIVE + _instances = [] def __new__(cls, lower=1e-6, *args, **kwargs): - if not cls._instance: - cls._instance = super(Transformation, cls).__new__( - cls, lower, *args, **kwargs) - elif cls._instance.lower == lower: - return cls._instance - else: - return super(Transformation, cls).__new__(cls, lower, *args, **kwargs) + if cls._instances: + cls._instances[:] = [instance for instance in cls._instances if instance()] + for instance in cls._instances: + if instance().lower == lower: + return instance() + o = super(Transformation, cls).__new__(cls, lower, *args, **kwargs) + cls._instances.append(weakref.ref(o)) + return cls._instances[-1]() def __init__(self, lower=1e-6): self.lower = lower def f(self, x): @@ -136,19 +138,20 @@ class Square(Transformation): class Logistic(Transformation): domain = BOUNDED + _instances = [] def __new__(cls, lower=1e-6, upper=1e-6, *args, **kwargs): - if not cls._instance: - cls._instance = super(Transformation, cls).__new__( - cls, *args, **kwargs) - elif cls._instance.lower == lower and cls._instance.upper == upper: - return cls._instance - else: - return super(Transformation, cls).__new__(cls, *args, **kwargs) + if cls._instances: + cls._instances[:] = [instance for instance in cls._instances if instance()] + for instance in cls._instances: + if instance().lower == lower and instance().upper == upper: + return instance() + o = super(Transformation, cls).__new__(cls, lower, upper, *args, **kwargs) + cls._instances.append(weakref.ref(o)) + return cls._instances[-1]() def __init__(self, lower, upper): assert lower < upper self.lower, self.upper = float(lower), float(upper) self.difference = self.upper - self.lower - return self def f(self, x): return self.lower + self.difference / (1. + np.exp(-x)) def finv(self, f):