mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 00:32:39 +02:00
[gp] memory > only one copy
This commit is contained in:
parent
06ee090a50
commit
8213b5011c
2 changed files with 12 additions and 10 deletions
|
|
@ -12,6 +12,7 @@ from .. import likelihoods
|
|||
from ..likelihoods.gaussian import Gaussian
|
||||
from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation, LatentFunctionInference
|
||||
from parameterization.variational import VariationalPosterior
|
||||
from scipy.sparse.base import issparse
|
||||
|
||||
class GP(Model):
|
||||
"""
|
||||
|
|
@ -34,12 +35,13 @@ class GP(Model):
|
|||
assert X.ndim == 2
|
||||
if isinstance(X, (ObsAr, VariationalPosterior)):
|
||||
self.X = X.copy()
|
||||
else: self.X = ObsAr(X.copy())
|
||||
else: self.X = ObsAr(X)
|
||||
|
||||
self.num_data, self.input_dim = self.X.shape
|
||||
|
||||
assert Y.ndim == 2
|
||||
self.Y = ObsAr(Y.copy())
|
||||
if issparse(Y): self.Y = Y
|
||||
else: self.Y = ObsAr(Y)
|
||||
assert Y.shape[0] == self.num_data
|
||||
_, self.output_dim = self.Y.shape
|
||||
|
||||
|
|
@ -199,9 +201,9 @@ class GP(Model):
|
|||
if fillcol is not None:
|
||||
kw['fillcol'] = fillcol
|
||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
data_symbol=data_symbol, **kw)
|
||||
|
||||
def plot(self, plot_limits=None, which_data_rows='all',
|
||||
|
|
@ -250,9 +252,9 @@ class GP(Model):
|
|||
if fillcol is not None:
|
||||
kw['fillcol'] = fillcol
|
||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
which_data_ycols, fixed_inputs,
|
||||
levels, samples, fignum, ax, resolution,
|
||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||
data_symbol=data_symbol, **kw)
|
||||
|
||||
def input_sensitivity(self):
|
||||
|
|
@ -281,4 +283,4 @@ class GP(Model):
|
|||
except KeyboardInterrupt:
|
||||
print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
|
||||
self.inference_method.on_optimization_end()
|
||||
raise
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ObsAr(np.ndarray, Pickleable, Observable):
|
|||
def _setup_observers(self):
|
||||
# do not setup anything, as observable arrays do not have default observers
|
||||
pass
|
||||
|
||||
|
||||
def copy(self):
|
||||
from lists_and_dicts import ObserverList
|
||||
memo = {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue