mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-11 21:12:38 +02:00
Changed all M's for num_inducing
This commit is contained in:
parent
aac4f6a237
commit
3475b52b6c
21 changed files with 142 additions and 142 deletions
|
|
@ -223,11 +223,11 @@ class kern(parameterised):
|
|||
def dK_dtheta(self, dL_dK, X, X2=None):
|
||||
"""
|
||||
:param dL_dK: An array of dL_dK derivaties, dL_dK
|
||||
:type dL_dK: Np.ndarray (N x M)
|
||||
:type dL_dK: Np.ndarray (N x num_inducing)
|
||||
:param X: Observed data inputs
|
||||
:type X: np.ndarray (N x input_dim)
|
||||
:param X2: Observed dara inputs (optional, defaults to X)
|
||||
:type X2: np.ndarray (M x input_dim)
|
||||
:type X2: np.ndarray (num_inducing x input_dim)
|
||||
"""
|
||||
assert X.shape[1] == self.input_dim
|
||||
target = np.zeros(self.Nparam)
|
||||
|
|
@ -300,16 +300,16 @@ class kern(parameterised):
|
|||
return target
|
||||
|
||||
def dpsi1_dmuS(self, dL_dpsi1, Z, mu, S):
|
||||
"""return shapes are N,M,input_dim"""
|
||||
"""return shapes are N,num_inducing,input_dim"""
|
||||
target_mu, target_S = np.zeros((2, mu.shape[0], mu.shape[1]))
|
||||
[p.dpsi1_dmuS(dL_dpsi1, Z[:, i_s], mu[:, i_s], S[:, i_s], target_mu[:, i_s], target_S[:, i_s]) for p, i_s in zip(self.parts, self.input_slices)]
|
||||
return target_mu, target_S
|
||||
|
||||
def psi2(self, Z, mu, S):
|
||||
"""
|
||||
:param Z: np.ndarray of inducing inputs (M x input_dim)
|
||||
:param Z: np.ndarray of inducing inputs (num_inducing x input_dim)
|
||||
:param mu, S: np.ndarrays of means and variances (each N x input_dim)
|
||||
:returns psi2: np.ndarray (N,M,M)
|
||||
:returns psi2: np.ndarray (N,num_inducing,num_inducing)
|
||||
"""
|
||||
target = np.zeros((mu.shape[0], Z.shape[0], Z.shape[0]))
|
||||
[p.psi2(Z[:, i_s], mu[:, i_s], S[:, i_s], target) for p, i_s in zip(self.parts, self.input_slices)]
|
||||
|
|
@ -328,7 +328,7 @@ class kern(parameterised):
|
|||
|
||||
prod = np.multiply(tmp1, tmp2)
|
||||
crossterms += prod[:,:,None] + prod[:, None, :]
|
||||
|
||||
|
||||
target += crossterms
|
||||
return target
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue