From 8213b5011cf2f4e8fdf373fc8f8576641113d84b Mon Sep 17 00:00:00 2001 From: mzwiessele Date: Fri, 27 Jun 2014 15:19:11 -0700 Subject: [PATCH] [gp] memory > only one copy --- GPy/core/gp.py | 20 ++++++++++--------- GPy/core/parameterization/observable_array.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/GPy/core/gp.py b/GPy/core/gp.py index c38820f3..2c26ee7d 100644 --- a/GPy/core/gp.py +++ b/GPy/core/gp.py @@ -12,6 +12,7 @@ from .. import likelihoods from ..likelihoods.gaussian import Gaussian from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation, LatentFunctionInference from parameterization.variational import VariationalPosterior +from scipy.sparse.base import issparse class GP(Model): """ @@ -34,12 +35,13 @@ class GP(Model): assert X.ndim == 2 if isinstance(X, (ObsAr, VariationalPosterior)): self.X = X.copy() - else: self.X = ObsAr(X.copy()) + else: self.X = ObsAr(X) self.num_data, self.input_dim = self.X.shape assert Y.ndim == 2 - self.Y = ObsAr(Y.copy()) + if issparse(Y): self.Y = Y + else: self.Y = ObsAr(Y) assert Y.shape[0] == self.num_data _, self.output_dim = self.Y.shape @@ -199,9 +201,9 @@ class GP(Model): if fillcol is not None: kw['fillcol'] = fillcol return models_plots.plot_fit(self, plot_limits, which_data_rows, - which_data_ycols, fixed_inputs, - levels, samples, fignum, ax, resolution, - plot_raw=plot_raw, Y_metadata=Y_metadata, + which_data_ycols, fixed_inputs, + levels, samples, fignum, ax, resolution, + plot_raw=plot_raw, Y_metadata=Y_metadata, data_symbol=data_symbol, **kw) def plot(self, plot_limits=None, which_data_rows='all', @@ -250,9 +252,9 @@ class GP(Model): if fillcol is not None: kw['fillcol'] = fillcol return models_plots.plot_fit(self, plot_limits, which_data_rows, - which_data_ycols, fixed_inputs, - levels, samples, fignum, ax, resolution, - plot_raw=plot_raw, Y_metadata=Y_metadata, + which_data_ycols, fixed_inputs, + levels, samples, fignum, ax, resolution, + plot_raw=plot_raw, Y_metadata=Y_metadata, data_symbol=data_symbol, **kw) def input_sensitivity(self): @@ -281,4 +283,4 @@ class GP(Model): except KeyboardInterrupt: print "KeyboardInterrupt caught, calling on_optimization_end() to round things up" self.inference_method.on_optimization_end() - raise \ No newline at end of file + raise diff --git a/GPy/core/parameterization/observable_array.py b/GPy/core/parameterization/observable_array.py index 24fad7b6..09450b08 100644 --- a/GPy/core/parameterization/observable_array.py +++ b/GPy/core/parameterization/observable_array.py @@ -33,7 +33,7 @@ class ObsAr(np.ndarray, Pickleable, Observable): def _setup_observers(self): # do not setup anything, as observable arrays do not have default observers pass - + def copy(self): from lists_and_dicts import ObserverList memo = {}