mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
Added (SLOW) Pure Python implementations of flat_to_triang and triang_to_flat
This commit is contained in:
parent
27c65003d2
commit
985b2ea70c
4 changed files with 48 additions and 12 deletions
|
|
@ -170,7 +170,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
Kmm = kern.K(Z).copy()
|
||||
diag.add(Kmm, self.const_jitter)
|
||||
if not np.isfinite(Kmm).all():
|
||||
print Kmm
|
||||
print(Kmm)
|
||||
Lm = jitchol(Kmm)
|
||||
|
||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||
|
|
|
|||
|
|
@ -26,11 +26,6 @@ class MappingGradChecker(GPy.core.Model):
|
|||
self.mapping.update_gradients(self.dL_dY, self.X)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MappingTests(unittest.TestCase):
|
||||
|
||||
def test_kernelmapping(self):
|
||||
|
|
@ -68,5 +63,5 @@ class MappingTests(unittest.TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print "Running unit tests, please be (very) patient..."
|
||||
print("Running unit tests, please be (very) patient...")
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,13 @@
|
|||
# Licensed under the GNU GPL version 3.0
|
||||
|
||||
import numpy as np
|
||||
from scipy import weave
|
||||
from . import linalg
|
||||
from .config import config
|
||||
|
||||
try:
|
||||
from scipy import weave
|
||||
except ImportError:
|
||||
config.set('weave', 'working', 'False')
|
||||
|
||||
def safe_root(N):
|
||||
i = np.sqrt(N)
|
||||
|
|
@ -12,12 +17,13 @@ def safe_root(N):
|
|||
raise ValueError("N is not square!")
|
||||
return j
|
||||
|
||||
def flat_to_triang(flat):
|
||||
def _flat_to_triang_weave(flat):
|
||||
"""take a matrix N x D and return a M X M x D array where
|
||||
|
||||
N = M(M+1)/2
|
||||
|
||||
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
|
||||
This is the weave implementation
|
||||
"""
|
||||
N, D = flat.shape
|
||||
M = (-1 + safe_root(8*N+1))/2
|
||||
|
|
@ -41,7 +47,24 @@ def flat_to_triang(flat):
|
|||
weave.inline(code, ['flat', 'ret', 'D', 'M'])
|
||||
return ret
|
||||
|
||||
def triang_to_flat(L):
|
||||
def _flat_to_triang_pure(flat_mat):
|
||||
N, D = flat_mat.shape
|
||||
M = (-1 + safe_root(8*N+1))//2
|
||||
ret = np.zeros((M, M, D))
|
||||
count = 0
|
||||
for m in range(M):
|
||||
for mm in range(m+1):
|
||||
for d in range(D):
|
||||
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
|
||||
count = count+1
|
||||
return ret
|
||||
|
||||
if config.getboolean('weave', 'working'):
|
||||
flat_to_triang = _flat_to_triang_weave
|
||||
else:
|
||||
flat_to_triang = _flat_to_triang_pure
|
||||
|
||||
def _triang_to_flat_weave(L):
|
||||
M, _, D = L.shape
|
||||
|
||||
L = np.ascontiguousarray(L) # should do nothing if L was created by flat_to_triang
|
||||
|
|
@ -65,6 +88,24 @@ def triang_to_flat(L):
|
|||
weave.inline(code, ['flat', 'L', 'D', 'M'])
|
||||
return flat
|
||||
|
||||
def _triang_to_flat_pure(L):
|
||||
M, _, D = L.shape
|
||||
|
||||
N = M*(M+1)//2
|
||||
flat = np.empty((N, D))
|
||||
count = 0;
|
||||
for m in range(M):
|
||||
for mm in range(m+1):
|
||||
for d in range(D):
|
||||
flat.flat[count] = L.flat[d + m*D*M + mm*D];
|
||||
count = count +1
|
||||
return flat
|
||||
|
||||
if config.getboolean('weave', 'working'):
|
||||
triang_to_flat = _triang_to_flat_weave
|
||||
else:
|
||||
triang_to_flat = _triang_to_flat_pure
|
||||
|
||||
def triang_to_cov(L):
|
||||
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
|
||||
|
||||
|
|
|
|||
|
|
@ -102,14 +102,14 @@ def jitchol(A, maxtries=5):
|
|||
num_tries = 1
|
||||
while num_tries <= maxtries and np.isfinite(jitter):
|
||||
try:
|
||||
print jitter
|
||||
print(jitter)
|
||||
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
||||
return L
|
||||
except:
|
||||
jitter *= 10
|
||||
finally:
|
||||
num_tries += 1
|
||||
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
||||
raise linalg.LinAlgError("not positive definite, even with jitter.")
|
||||
import traceback
|
||||
try: raise
|
||||
except:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue