mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-26 21:36:23 +02:00
add MRD for regression benchmark
This commit is contained in:
parent
8f78431983
commit
9564287cf3
2 changed files with 31 additions and 3 deletions
|
|
@ -54,13 +54,41 @@ class GP_RBF(RegressionMethod):
|
|||
|
||||
def _fit(self, train_data):
|
||||
inputs, labels = train_data
|
||||
self.model = GPy.models.GPRegression(inputs, labels,kernel=GPy.kern.RBF(inputs.shape[-1],ARD=True) +GPy.kern.Linear(inputs.shape[1], ARD=True) + GPy.kern.White(inputs.shape[1],0.01) )
|
||||
self.model = GPy.models.GPRegression(inputs, labels,kernel=GPy.kern.RBF(inputs.shape[-1],ARD=True) +GPy.kern.Linear(inputs.shape[1], ARD=True) )
|
||||
self.model.likelihood.variance[:] = labels.var()*0.01
|
||||
self.model.optimize()
|
||||
return True
|
||||
|
||||
def _predict(self, test_data):
|
||||
return self.model.predict(test_data)[0]
|
||||
|
||||
class SparseGP_RBF(RegressionMethod):
|
||||
name = 'SparseGP_RBF'
|
||||
|
||||
def _fit(self, train_data):
|
||||
inputs, labels = train_data
|
||||
self.model = GPy.models.SparseGPRegression(inputs, labels,kernel=GPy.kern.RBF(inputs.shape[-1],ARD=True) +GPy.kern.Linear(inputs.shape[1], ARD=True) ,num_inducing=100)
|
||||
self.model.likelihood.variance[:] = labels.var()*0.01
|
||||
self.model.optimize()
|
||||
return True
|
||||
|
||||
def _predict(self, test_data):
|
||||
return self.model.predict(test_data)[0]
|
||||
|
||||
# class MRD_RBF(RegressionMethod):
|
||||
# name = 'MRD_RBF'
|
||||
#
|
||||
# def _fit(self, train_data):
|
||||
# inputs, labels = train_data
|
||||
# Q = 5
|
||||
# self.model = GPy.models.MRD([inputs, labels],Q,kernel=GPy.kern.RBF(Q,ARD=True),num_inducing=50)
|
||||
# self.model.Y0.likelihood.variance[:] = inputs.var()*0.01
|
||||
# self.model.Y1.likelihood.variance[:] = labels.var()*0.01
|
||||
# self.model.optimize()
|
||||
# return True
|
||||
#
|
||||
# def _predict(self, test_data):
|
||||
# return self.model.predict(self.model.Y0.infer_newX(test_data)[0])[0]
|
||||
|
||||
class SVIGP_RBF(RegressionMethod):
|
||||
name = 'SVIGP_RBF'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue