Added block matrix dot product

This commit is contained in:
Alan Saul 2014-02-12 10:39:15 +00:00
parent 932b5468ae
commit 6a1de2bfc2

View file

@ -35,6 +35,37 @@ def unblock(B):
count_i += i
return A
def block_dot(A, B):
"""
Element wise dot product on block matricies
+------+------+ +------+------+ +-------+-------+
| | | | | | |A11.B11|B12.B12|
| A11 | A12 | | B11 | B12 | | | |
+------+------+ o +------+------| = +-------+-------+
| | | | | | |A21.B21|A22.B22|
| A21 | A22 | | B21 | B22 | | | |
+-------------+ +------+------+ +-------+-------+
..Note
If either (A or B) of the diagonal matrices are stored as vectors then a more
efficient dot product using numpy broadcasting will be used, i.e. A11*B11
"""
#Must have same number of blocks and be a block matrix
assert A.dtype is np.dtype('object'), "Must be a block matrix"
assert B.dtype is np.dtype('object'), "Must be a block matrix"
Ashape = A.shape
Bshape = B.shape
assert Ashape == Bshape
def f(A,B):
if Ashape[0] == Ashape[1] or Bshape[0] == Bshape[1]:
#FIXME: Careful if one is transpose of other, would make a matrix
return A*B
else:
return np.dot(A,B)
dot = np.vectorize(f, otypes = [np.object])
return dot(A,B)
if __name__=='__main__':
A = np.zeros((5,5))
@ -43,6 +74,3 @@ if __name__=='__main__':
print B
assert np.all(unblock(B) == A)
import ipdb; ipdb.set_trace() # XXX BREAKPOINT