mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
Sparse GP serialization
This commit is contained in:
parent
d85c9d5379
commit
7b2af57aee
4 changed files with 146 additions and 52 deletions
|
|
@ -7,6 +7,7 @@ from ..core import SparseGP
|
|||
from .. import likelihoods
|
||||
from .. import kern
|
||||
from ..inference.latent_function_inference import EPDTC
|
||||
from copy import deepcopy
|
||||
|
||||
class SparseGPClassification(SparseGP):
|
||||
"""
|
||||
|
|
@ -40,6 +41,27 @@ class SparseGPClassification(SparseGP):
|
|||
|
||||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method=EPDTC(), name='SparseGPClassification',Y_metadata=Y_metadata)
|
||||
|
||||
@staticmethod
|
||||
def from_sparse_gp(sparse_gp):
|
||||
from copy import deepcopy
|
||||
sparse_gp = deepcopy(sparse_gp)
|
||||
SparseGPClassification(sparse_gp.X, sparse_gp.Y, sparse_gp.Z, sparse_gp.kern, sparse_gp.likelihood, sparse_gp.inference_method, sparse_gp.mean_function, name='sparse_gp_classification')
|
||||
|
||||
def to_dict(self, save_data=True):
|
||||
model_dict = super(SparseGPClassification,self).to_dict(save_data)
|
||||
model_dict["class"] = "GPy.models.SparseGPClassification"
|
||||
return model_dict
|
||||
|
||||
@staticmethod
|
||||
def from_dict(input_dict, data=None):
|
||||
import GPy
|
||||
m = GPy.core.model.Model.from_dict(input_dict, data)
|
||||
return GPClassification.from_sparse_gp(m)
|
||||
|
||||
def save_model(self, output_filename, compress=True, save_data=True):
|
||||
self._save_model(output_filename, compress=True, save_data=True)
|
||||
|
||||
|
||||
class SparseGPClassificationUncertainInput(SparseGP):
|
||||
"""
|
||||
Sparse Gaussian Process model for classification with uncertain inputs.
|
||||
|
|
@ -87,8 +109,3 @@ class SparseGPClassificationUncertainInput(SparseGP):
|
|||
self.psi2 = self.kern.psi2n(self.Z, self.X)
|
||||
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata, psi0=self.psi0, psi1=self.psi1, psi2=self.psi2)
|
||||
self._update_gradients()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue