mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-30 15:26:23 +02:00
Added block matrix dot product
This commit is contained in:
parent
932b5468ae
commit
6a1de2bfc2
1 changed files with 31 additions and 3 deletions
|
|
@ -35,6 +35,37 @@ def unblock(B):
|
|||
count_i += i
|
||||
return A
|
||||
|
||||
def block_dot(A, B):
|
||||
"""
|
||||
Element wise dot product on block matricies
|
||||
|
||||
+------+------+ +------+------+ +-------+-------+
|
||||
| | | | | | |A11.B11|B12.B12|
|
||||
| A11 | A12 | | B11 | B12 | | | |
|
||||
+------+------+ o +------+------| = +-------+-------+
|
||||
| | | | | | |A21.B21|A22.B22|
|
||||
| A21 | A22 | | B21 | B22 | | | |
|
||||
+-------------+ +------+------+ +-------+-------+
|
||||
|
||||
..Note
|
||||
If either (A or B) of the diagonal matrices are stored as vectors then a more
|
||||
efficient dot product using numpy broadcasting will be used, i.e. A11*B11
|
||||
"""
|
||||
#Must have same number of blocks and be a block matrix
|
||||
assert A.dtype is np.dtype('object'), "Must be a block matrix"
|
||||
assert B.dtype is np.dtype('object'), "Must be a block matrix"
|
||||
Ashape = A.shape
|
||||
Bshape = B.shape
|
||||
assert Ashape == Bshape
|
||||
def f(A,B):
|
||||
if Ashape[0] == Ashape[1] or Bshape[0] == Bshape[1]:
|
||||
#FIXME: Careful if one is transpose of other, would make a matrix
|
||||
return A*B
|
||||
else:
|
||||
return np.dot(A,B)
|
||||
dot = np.vectorize(f, otypes = [np.object])
|
||||
return dot(A,B)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
A = np.zeros((5,5))
|
||||
|
|
@ -43,6 +74,3 @@ if __name__=='__main__':
|
|||
print B
|
||||
|
||||
assert np.all(unblock(B) == A)
|
||||
|
||||
import ipdb; ipdb.set_trace() # XXX BREAKPOINT
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue