diff --git a/GPy/util/block_matrices.py b/GPy/util/block_matrices.py index 95920868..cdbb1b0b 100644 --- a/GPy/util/block_matrices.py +++ b/GPy/util/block_matrices.py @@ -17,6 +17,23 @@ def get_blocks(A, blocksizes): count_i += i return B +def get_block_shapes(B): + assert B.dtype is np.dtype('object'), "Must be a block matrix" + return [B[b,b].shape[0] for b in range(0, B.shape[0])] + +def unblock(B): + assert B.dtype is np.dtype('object'), "Must be a block matrix" + block_shapes = get_block_shapes(B) + num_elements = np.sum(block_shapes) + A = np.empty(shape=(num_elements, num_elements)) + count_i = 0 + for Bi, i in enumerate(block_shapes): + count_j = 0 + for Bj, j in enumerate(block_shapes): + A[count_i:count_i + i, count_j:count_j + j] = B[Bi, Bj] + count_j += j + count_i += i + return A if __name__=='__main__': @@ -24,3 +41,8 @@ if __name__=='__main__': B = get_blocks(A,[2,3]) B[0,0] += 7 print B + + assert np.all(unblock(B) == A) + + import ipdb; ipdb.set_trace() # XXX BREAKPOINT +