mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
draft: update prior __new__ functions
This commit is contained in:
parent
1fcb40845d
commit
200798eb48
2 changed files with 105 additions and 1 deletions
|
|
@ -1275,7 +1275,12 @@ class HalfT(Prior):
|
||||||
for instance in cls._instances:
|
for instance in cls._instances:
|
||||||
if instance().A == A and instance().nu == nu:
|
if instance().A == A and instance().nu == nu:
|
||||||
return instance()
|
return instance()
|
||||||
o = super(Prior, cls).__new__(cls, A, nu)
|
|
||||||
|
newfunc = super(Prior, cls).__new__
|
||||||
|
if newfunc is object.__new__:
|
||||||
|
o = newfunc(cls)
|
||||||
|
else:
|
||||||
|
o = newfunc(cls, A, nu)
|
||||||
cls._instances.append(weakref.ref(o))
|
cls._instances.append(weakref.ref(o))
|
||||||
return cls._instances[-1]()
|
return cls._instances[-1]()
|
||||||
|
|
||||||
|
|
|
||||||
99
GPy/testing/test_priors.py
Normal file
99
GPy/testing/test_priors.py
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
import numpy
|
||||||
|
from GPy.core.parameterization.priors import (
|
||||||
|
Gaussian,
|
||||||
|
Uniform,
|
||||||
|
LogGaussian,
|
||||||
|
MultivariateGaussian,
|
||||||
|
Gamma,
|
||||||
|
InverseGamma,
|
||||||
|
DGPLVM,
|
||||||
|
DGPLVM_KFDA,
|
||||||
|
DGPLVM_Lamda,
|
||||||
|
DGPLVM_T,
|
||||||
|
HalfT,
|
||||||
|
Exponential,
|
||||||
|
StudentT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_gaussian_prior() -> None:
|
||||||
|
return Gaussian(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_uniform_prior() -> None:
|
||||||
|
return Uniform(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_log_gaussian_prior() -> None:
|
||||||
|
return LogGaussian(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_multivariate_gaussian_prior() -> None:
|
||||||
|
return MultivariateGaussian(numpy.zeros(2), numpy.eye(2))
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_gamma_prior() -> None:
|
||||||
|
return Gamma(1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_inverse_gamma_prior() -> None:
|
||||||
|
return InverseGamma(1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_dgplvm_prior() -> None:
|
||||||
|
return DGPLVM(1, 1, (1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_dgplvm_kfda_prior() -> None:
|
||||||
|
return DGPLVM_KFDA(1, 1, (1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_dgplvm_lamda_prior() -> None:
|
||||||
|
return DGPLVM_Lamda(1, 1, (1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_dgplvm_t_prior() -> None:
|
||||||
|
return DGPLVM_T(1, 1, (1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_half_t_prior() -> None:
|
||||||
|
return HalfT(1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_exponential_prior() -> None:
|
||||||
|
return Exponential(1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_student_t_prior() -> None:
|
||||||
|
return StudentT(1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
PRIORS = {
|
||||||
|
"Gaussian": initialize_gaussian_prior,
|
||||||
|
"Uniform": initialize_uniform_prior,
|
||||||
|
"LogGaussian": initialize_log_gaussian_prior,
|
||||||
|
"MultivariateGaussian": initialize_multivariate_gaussian_prior,
|
||||||
|
"Gamma": initialize_gamma_prior,
|
||||||
|
"InverseGamma": initialize_inverse_gamma_prior,
|
||||||
|
"DGPLVM": initialize_dgplvm_prior,
|
||||||
|
"DGPLVM_KFDA": initialize_dgplvm_kfda_prior,
|
||||||
|
"DGPLVM_Lamda": initialize_dgplvm_lamda_prior,
|
||||||
|
"DGPLVM_T": initialize_dgplvm_t_prior,
|
||||||
|
"HalfT": initialize_half_t_prior,
|
||||||
|
"Exponential": initialize_exponential_prior,
|
||||||
|
"StudentT": initialize_student_t_prior,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def check_prior(prior_getter: str) -> None:
|
||||||
|
prior_getter()
|
||||||
|
|
||||||
|
|
||||||
|
def test_priors() -> None:
|
||||||
|
for prior_name, prior_getter in PRIORS.items():
|
||||||
|
try:
|
||||||
|
check_prior(prior_getter)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to initialize {prior_name} prior"
|
||||||
|
) from e # noqa E501
|
||||||
Loading…
Add table
Add a link
Reference in a new issue