A Brief Walkthrough of Matrix Multiplication in Keras with TensorFlow as the Backend

Matrix Multiplication

Plain matrix multiplication is performed with tf.matmul in TensorFlow or with K.dot in Keras:

    from keras import backend as K
    a = K.ones((3,4))
    b = K.ones((4,5))
    c = K.dot(a, b)
    print(c.shape)

or

    import tensorflow as tf
    a = tf.ones((3,4))
    b = tf.ones((4,5))
    c = tf.matmul(a, b)
    print(c.shape)

Both return a tensor of shape (3, 5). Things stay simple as long as the tensors have at most two (or even three) dimensions. What we are after, though, is an approach that keeps working as the dimensionality grows: once higher-dimensional data and batches come in, batch matrix multiplication is what we need.
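
Before moving on, here is a quick sanity check of the plain 2-D case (a minimal sketch assuming the TensorFlow backend): evaluating the product confirms both the shape and the values, since every entry is the dot product of a length-4 row of ones with a length-4 column of ones.

    from keras import backend as K
    a = K.ones((3, 4))
    b = K.ones((4, 5))
    c = K.dot(a, b)
    print(K.eval(c))  # a (3, 5) array in which every entry is 4.0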

Simple Batch Matrix Multiplication: tf.matmul or K.batch_dot

There is another operator, K.batch_dot, that works the same way as tf.matmul:

    from keras import backend as K
    a = K.ones((9, 8, 7, 4, 2))
    b = K.ones((9, 8, 7, 2, 5))
    c = K.batch_dot(a, b)
    print(c.shape)

or

    import tensorflow as tf
    a = tf.ones((9, 8, 7, 4, 2))
    b = tf.ones((9, 8, 7, 2, 5))
    c = tf.matmul(a, b)
    print(c.shape)

Both return a tensor of shape (9, 8, 7, 4, 5).

So here the multiplication treats (9, 8, 7) as the batch (non-spatial) dimensions: the data is interpreted in (B1, …, Bn, C, H, W) format, and only the last two axes are treated as the spatial (matrix) dimensions. In this example the spatial dimensions of tensor a are (4, 2) and those of b are (2, 5), so they multiply to (4, 5).
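
To make the "last two axes only" rule concrete, here is a small sanity check (a sketch assuming the TensorFlow backend; the explicit einsum spelling is simply an equivalent way to write the same contraction):

    import tensorflow as tf
    a = tf.ones((9, 8, 7, 4, 2))
    b = tf.ones((9, 8, 7, 2, 5))
    c = tf.matmul(a, b)                              # contracts only the last two axes
    c_check = tf.einsum('xyzij,xyzjk->xyzik', a, b)  # the same contraction, written out
    print(c.shape, c_check.shape)                    # both (9, 8, 7, 4, 5)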

However, when the channel is the last dimension (the default channels-last data format), the spatial multiplication (height x width matrix multiplication) fails in K.batch_dot. We therefore need K.permute_dimensions as a preprocessing step before the batch matrix multiplication.
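
To see why it fails, look at the shapes involved (a sketch with hypothetical channels-last tensors): the last two axes are (width, channels), so the inner dimensions no longer line up and the multiplication is rejected.

    import tensorflow as tf
    a = tf.ones((10, 512, 256, 3))  # channels-last: last two axes are (256, 3)
    b = tf.ones((10, 256, 512, 3))  # channels-last: last two axes are (512, 3)
    try:
        c = tf.matmul(a, b)         # tries to contract 3 with 512
    except Exception as e:
        print(e)                    # dimension mismatch error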

Batch Matrix Multiplication: K.permute_dimensions and K.batch_dot

Take multi-channel data as an example.

    from keras import backend as K
    # batch_a: (10, 512, 256, 3) -- 10 as batch size, 512 x 256 as height x width, 3 channels
    # batch_b: (10, 256, 512, 3) -- 10 as batch size, 256 x 512 as height x width, 3 channels
    a = K.ones((10, 512, 256, 3))
    b = K.ones((10, 256, 512, 3))
    # move the channel axis next to the batch axis so the spatial axes come last
    a_t = K.permute_dimensions(a, (0, 3, 1, 2))  # K.int_shape(a_t) = (10, 3, 512, 256)
    b_t = K.permute_dimensions(b, (0, 3, 1, 2))  # K.int_shape(b_t) = (10, 3, 256, 512)
    c_t = K.batch_dot(a_t, b_t, axes=(3, 2))     # K.int_shape(c_t) = (10, 3, 512, 512)
    # move the channel axis back to the end
    c = K.permute_dimensions(c_t, (0, 2, 3, 1))  # K.int_shape(c) = (10, 512, 512, 3)
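
If the TensorFlow backend is in use, an equivalent way to express the same per-channel multiplication directly on channels-last data, without any permuting, is tf.einsum. This is only a sketch of an alternative, not something the K.batch_dot recipe above requires:

    import tensorflow as tf
    a = tf.ones((10, 512, 256, 3))   # (batch, height, width, channels)
    b = tf.ones((10, 256, 512, 3))
    # contract the shared 256-long axis per batch element and per channel
    c = tf.einsum('bijc,bjkc->bikc', a, b)
    print(c.shape)                   # (10, 512, 512, 3)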

This case remains open for a better answer; please comment if you can shed more light on it.
