[gp] memory > only one copy

This commit is contained in:
mzwiessele 2014-06-27 15:19:11 -07:00
parent 06ee090a50
commit 8213b5011c
2 changed files with 12 additions and 10 deletions

View file

@ -12,6 +12,7 @@ from .. import likelihoods
from ..likelihoods.gaussian import Gaussian from ..likelihoods.gaussian import Gaussian
from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation, LatentFunctionInference from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation, LatentFunctionInference
from parameterization.variational import VariationalPosterior from parameterization.variational import VariationalPosterior
from scipy.sparse.base import issparse
class GP(Model): class GP(Model):
""" """
@ -34,12 +35,13 @@ class GP(Model):
assert X.ndim == 2 assert X.ndim == 2
if isinstance(X, (ObsAr, VariationalPosterior)): if isinstance(X, (ObsAr, VariationalPosterior)):
self.X = X.copy() self.X = X.copy()
else: self.X = ObsAr(X.copy()) else: self.X = ObsAr(X)
self.num_data, self.input_dim = self.X.shape self.num_data, self.input_dim = self.X.shape
assert Y.ndim == 2 assert Y.ndim == 2
self.Y = ObsAr(Y.copy()) if issparse(Y): self.Y = Y
else: self.Y = ObsAr(Y)
assert Y.shape[0] == self.num_data assert Y.shape[0] == self.num_data
_, self.output_dim = self.Y.shape _, self.output_dim = self.Y.shape
@ -199,9 +201,9 @@ class GP(Model):
if fillcol is not None: if fillcol is not None:
kw['fillcol'] = fillcol kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows, return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs, which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution, levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata, plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw) data_symbol=data_symbol, **kw)
def plot(self, plot_limits=None, which_data_rows='all', def plot(self, plot_limits=None, which_data_rows='all',
@ -250,9 +252,9 @@ class GP(Model):
if fillcol is not None: if fillcol is not None:
kw['fillcol'] = fillcol kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows, return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs, which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution, levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata, plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw) data_symbol=data_symbol, **kw)
def input_sensitivity(self): def input_sensitivity(self):
@ -281,4 +283,4 @@ class GP(Model):
except KeyboardInterrupt: except KeyboardInterrupt:
print "KeyboardInterrupt caught, calling on_optimization_end() to round things up" print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
self.inference_method.on_optimization_end() self.inference_method.on_optimization_end()
raise raise

View file

@ -33,7 +33,7 @@ class ObsAr(np.ndarray, Pickleable, Observable):
def _setup_observers(self): def _setup_observers(self):
# do not setup anything, as observable arrays do not have default observers # do not setup anything, as observable arrays do not have default observers
pass pass
def copy(self): def copy(self):
from lists_and_dicts import ObserverList from lists_and_dicts import ObserverList
memo = {} memo = {}