Merge pull request #783 from MashaNaslidnyk/num-data-fix

Update self.num_data in GP when X is updated
This commit is contained in:
Zhenwen Dai 2020-03-13 10:08:05 +00:00 committed by GitHub
commit 1f9ac259ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 19 deletions

View file

@ -43,8 +43,6 @@ class GP(Model):
self.X = X.copy()
else: self.X = ObsAr(X)
self.num_data, self.input_dim = self.X.shape
assert Y.ndim == 2
logger.info("initializing Y")
@ -199,6 +197,14 @@ class GP(Model):
def _predictive_variable(self):
return self.X
@property
def num_data(self):
return self.X.shape[0]
@property
def input_dim(self):
return self.X.shape[1]
def set_XY(self, X=None, Y=None):
"""
Set the input / output data of the model
@ -235,6 +241,7 @@ class GP(Model):
self.link_parameter(self.X, index=index)
else:
self.X = ObsAr(X)
self.update_model(True)
def set_X(self,X):
@ -596,9 +603,9 @@ class GP(Model):
:param size: the number of a posteriori samples.
:type size: int.
:returns: set of simulations
:rtype: np.ndarray (Nnew x D x samples)
:rtype: np.ndarray (Nnew x D x samples)
"""
predict_kwargs["full_cov"] = True # Always use the full covariance for posterior samples.
predict_kwargs["full_cov"] = True # Always use the full covariance for posterior samples.
m, v = self._raw_predict(X, **predict_kwargs)
if self.normalizer is not None:
m, v = self.normalizer.inverse_mean(m), self.normalizer.inverse_variance(v)