From 6a1de2bfc2dccd30c20a0bb30902a283eca1b6d1 Mon Sep 17 00:00:00 2001 From: Alan Saul Date: Wed, 12 Feb 2014 10:39:15 +0000 Subject: [PATCH] Added block matrix dot product --- GPy/util/block_matrices.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/GPy/util/block_matrices.py b/GPy/util/block_matrices.py index cdbb1b0b..464e3ba1 100644 --- a/GPy/util/block_matrices.py +++ b/GPy/util/block_matrices.py @@ -35,6 +35,37 @@ def unblock(B): count_i += i return A +def block_dot(A, B): + """ + Element wise dot product on block matricies + + +------+------+ +------+------+ +-------+-------+ + | | | | | | |A11.B11|B12.B12| + | A11 | A12 | | B11 | B12 | | | | + +------+------+ o +------+------| = +-------+-------+ + | | | | | | |A21.B21|A22.B22| + | A21 | A22 | | B21 | B22 | | | | + +-------------+ +------+------+ +-------+-------+ + + ..Note + If either (A or B) of the diagonal matrices are stored as vectors then a more + efficient dot product using numpy broadcasting will be used, i.e. A11*B11 + """ + #Must have same number of blocks and be a block matrix + assert A.dtype is np.dtype('object'), "Must be a block matrix" + assert B.dtype is np.dtype('object'), "Must be a block matrix" + Ashape = A.shape + Bshape = B.shape + assert Ashape == Bshape + def f(A,B): + if Ashape[0] == Ashape[1] or Bshape[0] == Bshape[1]: + #FIXME: Careful if one is transpose of other, would make a matrix + return A*B + else: + return np.dot(A,B) + dot = np.vectorize(f, otypes = [np.object]) + return dot(A,B) + if __name__=='__main__': A = np.zeros((5,5)) @@ -43,6 +74,3 @@ if __name__=='__main__': print B assert np.all(unblock(B) == A) - - import ipdb; ipdb.set_trace() # XXX BREAKPOINT -