From 200798eb487e5f111cfa2aea8c9c64d39c2f844a Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Mon, 28 Oct 2024 23:06:31 +0100 Subject: [PATCH] draft: update prior __new__ functions --- GPy/core/parameterization/priors.py | 7 +- GPy/testing/test_priors.py | 99 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 GPy/testing/test_priors.py diff --git a/GPy/core/parameterization/priors.py b/GPy/core/parameterization/priors.py index 3550a8b5..bdb08415 100644 --- a/GPy/core/parameterization/priors.py +++ b/GPy/core/parameterization/priors.py @@ -1275,7 +1275,12 @@ class HalfT(Prior): for instance in cls._instances: if instance().A == A and instance().nu == nu: return instance() - o = super(Prior, cls).__new__(cls, A, nu) + + newfunc = super(Prior, cls).__new__ + if newfunc is object.__new__: + o = newfunc(cls) + else: + o = newfunc(cls, A, nu) cls._instances.append(weakref.ref(o)) return cls._instances[-1]() diff --git a/GPy/testing/test_priors.py b/GPy/testing/test_priors.py new file mode 100644 index 00000000..51e357e3 --- /dev/null +++ b/GPy/testing/test_priors.py @@ -0,0 +1,99 @@ +import numpy +from GPy.core.parameterization.priors import ( + Gaussian, + Uniform, + LogGaussian, + MultivariateGaussian, + Gamma, + InverseGamma, + DGPLVM, + DGPLVM_KFDA, + DGPLVM_Lamda, + DGPLVM_T, + HalfT, + Exponential, + StudentT, +) + + +def initialize_gaussian_prior() -> None: + return Gaussian(0, 1) + + +def initialize_uniform_prior() -> None: + return Uniform(0, 1) + + +def initialize_log_gaussian_prior() -> None: + return LogGaussian(0, 1) + + +def initialize_multivariate_gaussian_prior() -> None: + return MultivariateGaussian(numpy.zeros(2), numpy.eye(2)) + + +def initialize_gamma_prior() -> None: + return Gamma(1, 1) + + +def initialize_inverse_gamma_prior() -> None: + return InverseGamma(1, 1) + + +def initialize_dgplvm_prior() -> None: + return DGPLVM(1, 1, (1, 1)) + + +def initialize_dgplvm_kfda_prior() -> None: + return DGPLVM_KFDA(1, 1, (1, 1)) + + +def initialize_dgplvm_lamda_prior() -> None: + return DGPLVM_Lamda(1, 1, (1, 1)) + + +def initialize_dgplvm_t_prior() -> None: + return DGPLVM_T(1, 1, (1, 1)) + + +def initialize_half_t_prior() -> None: + return HalfT(1, 1) + + +def initialize_exponential_prior() -> None: + return Exponential(1) + + +def initialize_student_t_prior() -> None: + return StudentT(1, 1, 1) + + +PRIORS = { + "Gaussian": initialize_gaussian_prior, + "Uniform": initialize_uniform_prior, + "LogGaussian": initialize_log_gaussian_prior, + "MultivariateGaussian": initialize_multivariate_gaussian_prior, + "Gamma": initialize_gamma_prior, + "InverseGamma": initialize_inverse_gamma_prior, + "DGPLVM": initialize_dgplvm_prior, + "DGPLVM_KFDA": initialize_dgplvm_kfda_prior, + "DGPLVM_Lamda": initialize_dgplvm_lamda_prior, + "DGPLVM_T": initialize_dgplvm_t_prior, + "HalfT": initialize_half_t_prior, + "Exponential": initialize_exponential_prior, + "StudentT": initialize_student_t_prior, +} + + +def check_prior(prior_getter: str) -> None: + prior_getter() + + +def test_priors() -> None: + for prior_name, prior_getter in PRIORS.items(): + try: + check_prior(prior_getter) + except Exception as e: + raise RuntimeError( + f"Failed to initialize {prior_name} prior" + ) from e # noqa E501