Added (SLOW) Pure Python implementations of flat_to_triang and triang_to_flat

This commit is contained in:
Mike Croucher 2015-04-01 15:42:49 +01:00
parent 27c65003d2
commit 985b2ea70c
4 changed files with 48 additions and 12 deletions

View file

@ -170,7 +170,7 @@ class VarDTC_minibatch(LatentFunctionInference):
Kmm = kern.K(Z).copy() Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter) diag.add(Kmm, self.const_jitter)
if not np.isfinite(Kmm).all(): if not np.isfinite(Kmm).all():
print Kmm print(Kmm)
Lm = jitchol(Kmm) Lm = jitchol(Kmm)
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right') LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')

View file

@ -26,11 +26,6 @@ class MappingGradChecker(GPy.core.Model):
self.mapping.update_gradients(self.dL_dY, self.X) self.mapping.update_gradients(self.dL_dY, self.X)
class MappingTests(unittest.TestCase): class MappingTests(unittest.TestCase):
def test_kernelmapping(self): def test_kernelmapping(self):
@ -68,5 +63,5 @@ class MappingTests(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
print "Running unit tests, please be (very) patient..." print("Running unit tests, please be (very) patient...")
unittest.main() unittest.main()

View file

@ -2,8 +2,13 @@
# Licensed under the GNU GPL version 3.0 # Licensed under the GNU GPL version 3.0
import numpy as np import numpy as np
from scipy import weave
from . import linalg from . import linalg
from .config import config
try:
from scipy import weave
except ImportError:
config.set('weave', 'working', 'False')
def safe_root(N): def safe_root(N):
i = np.sqrt(N) i = np.sqrt(N)
@ -12,12 +17,13 @@ def safe_root(N):
raise ValueError("N is not square!") raise ValueError("N is not square!")
return j return j
def flat_to_triang(flat): def _flat_to_triang_weave(flat):
"""take a matrix N x D and return a M X M x D array where """take a matrix N x D and return a M X M x D array where
N = M(M+1)/2 N = M(M+1)/2
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat. the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
This is the weave implementation
""" """
N, D = flat.shape N, D = flat.shape
M = (-1 + safe_root(8*N+1))/2 M = (-1 + safe_root(8*N+1))/2
@ -41,7 +47,24 @@ def flat_to_triang(flat):
weave.inline(code, ['flat', 'ret', 'D', 'M']) weave.inline(code, ['flat', 'ret', 'D', 'M'])
return ret return ret
def triang_to_flat(L): def _flat_to_triang_pure(flat_mat):
N, D = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2
ret = np.zeros((M, M, D))
count = 0
for m in range(M):
for mm in range(m+1):
for d in range(D):
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
count = count+1
return ret
if config.getboolean('weave', 'working'):
flat_to_triang = _flat_to_triang_weave
else:
flat_to_triang = _flat_to_triang_pure
def _triang_to_flat_weave(L):
M, _, D = L.shape M, _, D = L.shape
L = np.ascontiguousarray(L) # should do nothing if L was created by flat_to_triang L = np.ascontiguousarray(L) # should do nothing if L was created by flat_to_triang
@ -65,6 +88,24 @@ def triang_to_flat(L):
weave.inline(code, ['flat', 'L', 'D', 'M']) weave.inline(code, ['flat', 'L', 'D', 'M'])
return flat return flat
def _triang_to_flat_pure(L):
M, _, D = L.shape
N = M*(M+1)//2
flat = np.empty((N, D))
count = 0;
for m in range(M):
for mm in range(m+1):
for d in range(D):
flat.flat[count] = L.flat[d + m*D*M + mm*D];
count = count +1
return flat
if config.getboolean('weave', 'working'):
triang_to_flat = _triang_to_flat_weave
else:
triang_to_flat = _triang_to_flat_pure
def triang_to_cov(L): def triang_to_cov(L):
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])]) return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])

View file

@ -102,14 +102,14 @@ def jitchol(A, maxtries=5):
num_tries = 1 num_tries = 1
while num_tries <= maxtries and np.isfinite(jitter): while num_tries <= maxtries and np.isfinite(jitter):
try: try:
print jitter print(jitter)
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True) L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
return L return L
except: except:
jitter *= 10 jitter *= 10
finally: finally:
num_tries += 1 num_tries += 1
raise linalg.LinAlgError, "not positive definite, even with jitter." raise linalg.LinAlgError("not positive definite, even with jitter.")
import traceback import traceback
try: raise try: raise
except: except: