mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-24 20:36:23 +02:00
update prior tests
This commit is contained in:
parent
200798eb48
commit
5a9e5e3cc4
3 changed files with 109 additions and 1 deletions
|
|
@ -1,6 +1,8 @@
|
|||
# Changelog
|
||||
|
||||
## Unreleased
|
||||
* update prior `__new__` methods #1098 [MartinBubel]
|
||||
|
||||
* fix invalid escape sequence #1011 [janmayer]
|
||||
|
||||
## v1.13.2 (2024-07-21)
|
||||
|
|
|
|||
|
|
@ -580,7 +580,11 @@ class DGPLVM(Prior):
|
|||
domain = _REAL
|
||||
|
||||
def __new__(cls, sigma2, lbl, x_shape):
|
||||
return super(Prior, cls).__new__(cls, sigma2, lbl, x_shape)
|
||||
newfunc = super(Prior, cls).__new__
|
||||
if newfunc is object.__new__:
|
||||
return newfunc(cls)
|
||||
else:
|
||||
return newfunc(cls, sigma2, lbl, x_shape)
|
||||
|
||||
def __init__(self, sigma2, lbl, x_shape):
|
||||
self.sigma2 = sigma2
|
||||
|
|
|
|||
|
|
@ -3,6 +3,21 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
import GPy
|
||||
from GPy.core.parameterization.priors import (
|
||||
Gaussian,
|
||||
Uniform,
|
||||
LogGaussian,
|
||||
MultivariateGaussian,
|
||||
Gamma,
|
||||
InverseGamma,
|
||||
DGPLVM,
|
||||
DGPLVM_KFDA,
|
||||
DGPLVM_Lamda,
|
||||
DGPLVM_T,
|
||||
HalfT,
|
||||
Exponential,
|
||||
StudentT,
|
||||
)
|
||||
|
||||
|
||||
class TestPrior:
|
||||
|
|
@ -178,3 +193,90 @@ class TestPrior:
|
|||
# should raise an assertionerror.
|
||||
with pytest.raises(AssertionError):
|
||||
m.rbf.set_prior(gaussian)
|
||||
|
||||
|
||||
def initialize_gaussian_prior() -> None:
|
||||
return Gaussian(0, 1)
|
||||
|
||||
|
||||
def initialize_uniform_prior() -> None:
|
||||
return Uniform(0, 1)
|
||||
|
||||
|
||||
def initialize_log_gaussian_prior() -> None:
|
||||
return LogGaussian(0, 1)
|
||||
|
||||
|
||||
def initialize_multivariate_gaussian_prior() -> None:
|
||||
return MultivariateGaussian(np.zeros(2), np.eye(2))
|
||||
|
||||
|
||||
def initialize_gamma_prior() -> None:
|
||||
return Gamma(1, 1)
|
||||
|
||||
|
||||
def initialize_inverse_gamma_prior() -> None:
|
||||
return InverseGamma(1, 1)
|
||||
|
||||
|
||||
def initialize_dgplvm_prior() -> None:
|
||||
# return DGPLVM(...)
|
||||
raise NotImplementedError("No idea how to initialize this prior")
|
||||
|
||||
|
||||
def initialize_dgplvm_kfda_prior() -> None:
|
||||
# return DGPLVM_KFDA(...)
|
||||
raise NotImplementedError("No idea how to initialize this prior")
|
||||
|
||||
|
||||
def initialize_dgplvm_lamda_prior() -> None:
|
||||
# return DGPLVM_Lamda(...)
|
||||
raise NotImplementedError("No idea how to initialize this prior")
|
||||
|
||||
|
||||
def initialize_dgplvm_t_prior() -> None:
|
||||
# return DGPLVM_T(1, 1, (1, 1))
|
||||
raise NotImplementedError("No idea how to initialize this prior")
|
||||
|
||||
|
||||
def initialize_half_t_prior() -> None:
|
||||
return HalfT(1, 1)
|
||||
|
||||
|
||||
def initialize_exponential_prior() -> None:
|
||||
return Exponential(1)
|
||||
|
||||
|
||||
def initialize_student_t_prior() -> None:
|
||||
return StudentT(1, 1, 1)
|
||||
|
||||
|
||||
PRIORS = {
|
||||
"Gaussian": initialize_gaussian_prior,
|
||||
"Uniform": initialize_uniform_prior,
|
||||
"LogGaussian": initialize_log_gaussian_prior,
|
||||
"MultivariateGaussian": initialize_multivariate_gaussian_prior,
|
||||
"Gamma": initialize_gamma_prior,
|
||||
"InverseGamma": initialize_inverse_gamma_prior,
|
||||
# "DGPLVM": initialize_dgplvm_prior,
|
||||
# "DGPLVM_KFDA": initialize_dgplvm_kfda_prior,
|
||||
# "DGPLVM_Lamda": initialize_dgplvm_lamda_prior,
|
||||
# "DGPLVM_T": initialize_dgplvm_t_prior,
|
||||
"HalfT": initialize_half_t_prior,
|
||||
"Exponential": initialize_exponential_prior,
|
||||
"StudentT": initialize_student_t_prior,
|
||||
}
|
||||
|
||||
|
||||
def check_prior(prior_getter: str) -> None:
|
||||
prior_getter()
|
||||
|
||||
|
||||
def test_priors() -> None:
|
||||
for prior_name, prior_getter in PRIORS.items():
|
||||
try:
|
||||
check_prior(prior_getter)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize {prior_name} prior"
|
||||
) from e # noqa E501
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue