diff --git a/GPy/__init__.py b/GPy/__init__.py index ecc3cf3c..392cefd3 100644 --- a/GPy/__init__.py +++ b/GPy/__init__.py @@ -39,18 +39,28 @@ def load(file_or_path): :param file_name: path/to/file.pickle """ + # This is the pickling pain when changing _src -> src try: - import cPickle as pickle - if isinstance(file_or_path, basestring): - with open(file_or_path, 'rb') as f: - m = pickle.load(f) - else: - m = pickle.load(file_or_path) - except: - import pickle - if isinstance(file_or_path, str): - with open(file_or_path, 'rb') as f: - m = pickle.load(f) - else: - m = pickle.load(file_or_path) + try: + import cPickle as pickle + if isinstance(file_or_path, basestring): + with open(file_or_path, 'rb') as f: + m = pickle.load(f) + else: + m = pickle.load(file_or_path) + except: + import pickle + if isinstance(file_or_path, str): + with open(file_or_path, 'rb') as f: + m = pickle.load(f) + else: + m = pickle.load(file_or_path) + except ImportError: + import sys + import inspect + sys.modules['GPy.kern._src'] = kern.src + for name, module in inspect.getmembers(kern.src): + if not name.startswith('_'): + sys.modules['GPy.kern._src.{}'.format(name)] = module + m = load(file_or_path) return m diff --git a/GPy/plotting/__init__.py b/GPy/plotting/__init__.py index 28c05cef..e4fe7080 100644 --- a/GPy/plotting/__init__.py +++ b/GPy/plotting/__init__.py @@ -13,6 +13,7 @@ def change_plotting_library(lib): if lib == 'matplotlib': import matplotlib from .matplot_dep.plot_definitions import MatplotlibPlots + from .matplot_dep import visualize, mapping_plots, priors_plots, ssgplvm, svig_plots, variational_plots, img_plots current_lib[0] = MatplotlibPlots() if lib == 'plotly': import plotly @@ -22,10 +23,11 @@ def change_plotting_library(lib): current_lib[0] = None #=========================================================================== except (ImportError, NameError): + raise config.set('plotting', 'library', 'none') import warnings warnings.warn(ImportWarning("{} not available, install newest version of {} for plotting".format(lib, lib))) - + from ..util.config import config lib = config.get('plotting', 'library') change_plotting_library(lib) diff --git a/GPy/plotting/gpy_plot/plot_util.py b/GPy/plotting/gpy_plot/plot_util.py index 5a884b8e..715ca759 100644 --- a/GPy/plotting/gpy_plot/plot_util.py +++ b/GPy/plotting/gpy_plot/plot_util.py @@ -275,7 +275,7 @@ def get_x_y_var(model): and Y the outputs If (X, X_variance, Y) is given, this just returns. - + :returns: (X, X_variance, Y) """ # model given @@ -285,7 +285,10 @@ def get_x_y_var(model): else: X = model.X.values X_variance = None - Y = model.Y.values + try: + Y = model.Y.values + except AttributeError: + Y = model.Y if sparse.issparse(Y): Y = Y.todense().view(np.ndarray) return X, X_variance, Y diff --git a/GPy/plotting/matplot_dep/mapping_plots.py b/GPy/plotting/matplot_dep/mapping_plots.py index c563d392..8ec09758 100644 --- a/GPy/plotting/matplot_dep/mapping_plots.py +++ b/GPy/plotting/matplot_dep/mapping_plots.py @@ -7,7 +7,6 @@ try: from matplotlib import pyplot as pb except: pass -from .base_plots import x_frame1D, x_frame2D def plot_mapping(self, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20, samples=0, fignum=None, ax=None, fixed_inputs=[], linecol=Tango.colorsHex['darkBlue']): @@ -52,6 +51,7 @@ def plot_mapping(self, plot_limits=None, which_data='all', which_parts='all', re ax = fig.add_subplot(111) plotdims = self.input_dim - len(fixed_inputs) + from ..gpy_plot.plot_util import x_frame1D, x_frame2D if plotdims == 1: diff --git a/GPy/plotting/matplot_dep/variational_plots.py b/GPy/plotting/matplot_dep/variational_plots.py index 24e613aa..34681552 100644 --- a/GPy/plotting/matplot_dep/variational_plots.py +++ b/GPy/plotting/matplot_dep/variational_plots.py @@ -94,7 +94,7 @@ def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_sid a.set_xticklabels('') # binary prob plot a = fig.add_subplot(*sub2) - a.bar(x,gamma[:,i],bottom=0.,linewidth=0,width=1.0,align='center') + a.bar(x,gamma[:,i],bottom=0.,linewidth=1.,width=1.0,align='center') a.set_xlim(x.min(), x.max()) a.set_ylim([0.,1.]) pb.draw() diff --git a/GPy/testing/pickle_test.pickle b/GPy/testing/pickle_test.pickle new file mode 100644 index 00000000..568e7a60 Binary files /dev/null and b/GPy/testing/pickle_test.pickle differ diff --git a/GPy/testing/pickle_tests.py b/GPy/testing/pickle_tests.py index 40f690c4..6836f75f 100644 --- a/GPy/testing/pickle_tests.py +++ b/GPy/testing/pickle_tests.py @@ -30,6 +30,12 @@ class ListDictTestCase(unittest.TestCase): np.testing.assert_array_equal(a1, a2) class Test(ListDictTestCase): + def test_load_pickle(self): + import os, GPy + m = GPy.load(os.path.join(os.path.abspath(os.path.split(__file__)[0]), 'pickle_test.pickle')) + self.assertTrue(m.checkgrad()) + self.assertEqual(m.log_likelihood(), -4.7351019830022087) + def test_model(self): par = toy_model() pcopy = par.copy() diff --git a/MANIFEST.in b/MANIFEST.in index 1800fa52..91f053cd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,11 +1,7 @@ -include *.txt -recursive-include doc *.txt -include *.md -recursive-include doc *.md -include *.cfg recursive-include doc *.cfg include *.json -recursive-include doc *.json recursive-include GPy *.c recursive-include GPy *.so recursive-include GPy *.pyx +include GPy/testing/plotting_tests/baseline/*.png +include GPy/testing/pickle_test.pickle diff --git a/setup.py b/setup.py index 0ba19167..1ad5e2fd 100644 --- a/setup.py +++ b/setup.py @@ -7,21 +7,21 @@ # Copyright (c) 2015, Max Zwiessele # # All rights reserved. -# +# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: -# +# # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. -# +# # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. -# +# # * Neither the name of GPy nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. -# +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -44,7 +44,7 @@ import codecs def read(fname): with codecs.open(fname, 'r', 'latin') as f: return f.read() - + def read_to_rst(fname): try: import pypandoc @@ -102,32 +102,35 @@ setup(name = 'GPy', ext_modules = ext_mods, packages = ["GPy", "GPy.core", - "GPy.core.parameterization", + "GPy.core.parameterization", "GPy.kern", "GPy.kern.src", - "GPy.kern.src.psi_comp", + "GPy.kern.src.psi_comp", "GPy.models", "GPy.inference", "GPy.inference.optimization", "GPy.inference.mcmc", "GPy.inference.latent_function_inference", - "GPy.likelihoods", + "GPy.likelihoods", "GPy.mappings", "GPy.examples", "GPy.testing", - "GPy.util", + "GPy.util", "GPy.plotting", "GPy.plotting.gpy_plot", - "GPy.plotting.matplot_dep", + "GPy.plotting.matplot_dep", "GPy.plotting.matplot_dep.controllers", - "GPy.plotting.plotly_dep", + "GPy.plotting.plotly_dep", ], package_dir={'GPy': 'GPy'}, package_data = {'GPy': ['defaults.cfg', 'installation.cfg', 'util/data_resources.json', 'util/football_teams.json', - 'plotting/plotting_tests/baseline/*.png' + 'testing/plotting_tests/baseline/*.png' ]}, + data_files=[('GPy/testing/plotting_tests/baseline', 'testing/plotting_tests/baseline/*.png'), + ('GPy/testing/', 'GPy/testing/pickle_test.pickle'), + ], include_package_data = True, py_modules = ['GPy.__init__'], test_suite = 'GPy.testing', @@ -170,7 +173,7 @@ if not os.path.exists(user_file): if os.path.exists(old_user_file): # Move it to new location: print("GPy: Found old config file, moving to new location {}".format(user_file)) - os.rename(old_user_file, user_file) + os.rename(old_user_file, user_file) else: # No config file exists, save informative stub to user config folder: print("GPy: Saving user configuration file to {}".format(user_file))