mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
Shape of heteroscedastic variance corrected
This commit is contained in:
parent
06f540584b
commit
2745904c01
1 changed files with 11 additions and 6 deletions
|
|
@ -335,20 +335,25 @@ class HeteroscedasticGaussian(Gaussian):
|
||||||
print("Warning, Exact inference is not implemeted for non-identity link functions,\
|
print("Warning, Exact inference is not implemeted for non-identity link functions,\
|
||||||
if you are not already, ensure Laplace inference_method is used")
|
if you are not already, ensure Laplace inference_method is used")
|
||||||
|
|
||||||
super(HeteroscedasticGaussian, self).__init__(gp_link, np.ones(Y_metadata['output_index'].shape[0])*variance, name)
|
super(HeteroscedasticGaussian, self).__init__(gp_link, np.ones(Y_metadata['output_index'].shape)*variance, name)
|
||||||
|
|
||||||
def exact_inference_gradients(self, dL_dKdiag,Y_metadata=None):
|
def exact_inference_gradients(self, dL_dKdiag,Y_metadata=None):
|
||||||
return dL_dKdiag[Y_metadata['output_index']][:,0]
|
return dL_dKdiag[Y_metadata['output_index']]
|
||||||
|
|
||||||
def gaussian_variance(self, Y_metadata=None):
|
def gaussian_variance(self, Y_metadata=None):
|
||||||
return self.variance[Y_metadata['output_index']]
|
return self.variance[Y_metadata['output_index'].flatten()]
|
||||||
|
|
||||||
def predictive_values(self, mu, var, full_cov=False, Y_metadata=None):
|
def predictive_values(self, mu, var, full_cov=False, Y_metadata=None):
|
||||||
|
_s = self.variance[Y_metadata['output_index'].flatten()]
|
||||||
if full_cov:
|
if full_cov:
|
||||||
if var.ndim == 2:
|
if var.ndim == 2:
|
||||||
var += np.eye(var.shape[0])*self.variance
|
var += np.eye(var.shape[0])*_s
|
||||||
if var.ndim == 3:
|
if var.ndim == 3:
|
||||||
var += np.atleast_3d(np.eye(var.shape[0])*self.variance)
|
var += np.atleast_3d(np.eye(var.shape[0])*_s)
|
||||||
else:
|
else:
|
||||||
var += self.variance
|
var += _s
|
||||||
return mu, var
|
return mu, var
|
||||||
|
|
||||||
|
def predictive_quantiles(self, mu, var, quantiles, Y_metadata=None):
|
||||||
|
_s = self.variance[Y_metadata['output_index'].flatten()]
|
||||||
|
return [stats.norm.ppf(q/100.)*np.sqrt(var + _s) + mu for q in quantiles]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue