mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
Changes to sympy covariance.
This commit is contained in:
parent
d0c563ff0a
commit
d1b6d18ddf
1 changed files with 24 additions and 10 deletions
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import sympy as sp
|
||||
from sympy.utilities.codegen import codegen
|
||||
from sympy.core.cache import clear_cache
|
||||
|
||||
from scipy import weave
|
||||
import re
|
||||
import os
|
||||
|
|
@ -28,39 +29,47 @@ class spkern(Kernpart):
|
|||
- to handle multpile correlated outputs, you'll need to add parameters with an index, such as lengthscale_i and lengthscale_j.
|
||||
"""
|
||||
def __init__(self, input_dim, k=None, output_dim=1, name=None, param=None):
|
||||
|
||||
if name is None:
|
||||
self.name='sympykern'
|
||||
else:
|
||||
self.name = name
|
||||
name='sympykern'
|
||||
if k is None:
|
||||
raise ValueError, "You must provide an argument for the covariance function."
|
||||
super(spkern, self).__init__(input_dim, name)
|
||||
|
||||
self._sp_k = k
|
||||
|
||||
# pull the variable names out of the symbolic covariance function.
|
||||
sp_vars = [e for e in k.atoms() if e.is_Symbol]
|
||||
self._sp_x= sorted([e for e in sp_vars if e.name[0:2]=='x_'],key=lambda x:int(x.name[2:]))
|
||||
self._sp_z= sorted([e for e in sp_vars if e.name[0:2]=='z_'],key=lambda z:int(z.name[2:]))
|
||||
|
||||
# Check that variable names make sense.
|
||||
assert all([x.name=='x_%i'%i for i,x in enumerate(self._sp_x)])
|
||||
assert all([z.name=='z_%i'%i for i,z in enumerate(self._sp_z)])
|
||||
assert len(self._sp_x)==len(self._sp_z)
|
||||
self.input_dim = len(self._sp_x)
|
||||
assert len(self._sp_x)==input_dim
|
||||
|
||||
# If it is a multi-output covariance, add an input for indexing the outputs.
|
||||
self._real_input_dim = self.input_dim
|
||||
if output_dim > 1:
|
||||
self.input_dim += 1
|
||||
assert self.input_dim == input_dim
|
||||
self.output_dim = output_dim
|
||||
# extract parameter names
|
||||
|
||||
# extract parameter names from the covariance
|
||||
thetas = sorted([e for e in sp_vars if not (e.name[0:2]=='x_' or e.name[0:2]=='z_')],key=lambda e:e.name)
|
||||
|
||||
|
||||
# Look for parameters with index.
|
||||
# Look for parameters with index (subscripts), they are associated with different outputs.
|
||||
if self.output_dim>1:
|
||||
self._sp_theta_i = sorted([e for e in thetas if (e.name[-2:]=='_i')], key=lambda e:e.name)
|
||||
self._sp_theta_j = sorted([e for e in thetas if (e.name[-2:]=='_j')], key=lambda e:e.name)
|
||||
|
||||
# Make sure parameter appears with both indices!
|
||||
assert len(self._sp_theta_i)==len(self._sp_theta_j)
|
||||
assert all([theta_i.name[:-2]==theta_j.name[:-2] for theta_i, theta_j in zip(self._sp_theta_i, self._sp_theta_j)])
|
||||
|
||||
# Extract names of shared parameters
|
||||
# Extract names of shared parameters (those without a subscript)
|
||||
self._sp_theta = [theta for theta in thetas if theta not in self._sp_theta_i and theta not in self._sp_theta_j]
|
||||
|
||||
self.num_split_params = len(self._sp_theta_i)
|
||||
|
|
@ -77,7 +86,8 @@ class spkern(Kernpart):
|
|||
self._sp_theta = thetas
|
||||
self.num_shared_params = len(self._sp_theta)
|
||||
self.num_params = self.num_shared_params
|
||||
|
||||
|
||||
# Add parameters to the model.
|
||||
for theta in self._sp_theta:
|
||||
val = 1.0
|
||||
if param is not None:
|
||||
|
|
@ -87,18 +97,22 @@ class spkern(Kernpart):
|
|||
#deal with param
|
||||
self._set_params(self._get_params())
|
||||
|
||||
#Differentiate!
|
||||
# Differentiate with respect to parameters.
|
||||
self._sp_dk_dtheta = [sp.diff(k,theta).simplify() for theta in self._sp_theta]
|
||||
if self.output_dim > 1:
|
||||
self._sp_dk_dtheta_i = [sp.diff(k,theta).simplify() for theta in self._sp_theta_i]
|
||||
|
||||
|
||||
# differentiate with respect to input variables.
|
||||
self._sp_dk_dx = [sp.diff(k,xi).simplify() for xi in self._sp_x]
|
||||
|
||||
# psi_stats aren't yet implemented.
|
||||
if False:
|
||||
self.compute_psi_stats()
|
||||
|
||||
# generate the code for the covariance functions
|
||||
self._gen_code()
|
||||
|
||||
if weave
|
||||
if False:
|
||||
extra_compile_args = ['-ftree-vectorize', '-mssse3', '-ftree-vectorizer-verbose=5']
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue