A Brief Walkthrough of Matrix Multiplication in Keras with TensorFlow as Backend
Matrix Multiplication
Matrix multiplication is performed with tf.matmul in TensorFlow, or K.dot in Keras:
from keras import backend as K
a = K.ones((3,4))
b = K.ones((4,5))
c = K.dot(a, b)
print(c.shape)
or
import tensorflow as tf
a = tf.ones((3,4))
b = tf.ones((4,5))
c = tf.matmul(a, b)
print(c.shape)
returns a tensor of shape (3, 5) in both cases. Things are simple when the tensor rank is no greater than 2, or even 3. However, we want an approach that stays compatible as the rank grows. When the rank is higher (e.g. with batch data or extra dimensions), batch matrix multiplication is what we need.
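To see why a plain dot is not enough for batched data, here is a small numpy sketch (an illustration under the assumption that K.dot follows the same Theano/numpy-style nD dot convention, as the Keras backend docs describe):

```python
import numpy as np

# np.dot sums over the last axis of a and the second-to-last axis of b,
# then concatenates ALL remaining axes -- the batch axes get crossed
# instead of being matched up.
a = np.ones((9, 4, 2))
b = np.ones((9, 2, 5))
print(np.dot(a, b).shape)     # (9, 4, 9, 5) -- batch axes crossed, not what we want
print(np.matmul(a, b).shape)  # (9, 4, 5)    -- batch-wise matrix multiplication
```

np.matmul (like tf.matmul) treats every leading axis as a batch axis and multiplies only the last two, which is the behavior the rest of this post relies on.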
Simple Batch Matrix Multiplication: tf.matmul or K.batch_dot
There is another operator, K.batch_dot, that works like tf.matmul:
from keras import backend as K
a = K.ones((9, 8, 7, 4, 2))
b = K.ones((9, 8, 7, 2, 5))
c = K.batch_dot(a, b)
print(c.shape)
or
import tensorflow as tf
a = tf.ones((9, 8, 7, 4, 2))
b = tf.ones((9, 8, 7, 2, 5))
c = tf.matmul(a, b)
print(c.shape)
returns a tensor of shape (9, 8, 7, 4, 5) in both cases.
So, here the multiplication has been performed treating (9, 8, 7) as the batch (non-spatial) dimensions. The data is assumed to be in (B1, …, Bn, C, H, W) format: the spatial dimensions of the tensor occupy the last two axes. Here, the spatial dimensions of a are (4, 2) and those of b are (2, 5).
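The batch contraction above can be checked with a numpy sketch: np.einsum with the subscripts '...ij,...jk->...ik' spells out exactly the contraction that batch matmul performs.

```python
import numpy as np

a = np.random.rand(9, 8, 7, 4, 2)
b = np.random.rand(9, 8, 7, 2, 5)

# Batch matmul: every leading axis is a batch axis; only the last two
# axes take part in the matrix product.
c = np.einsum('...ij,...jk->...ik', a, b)
print(c.shape)  # (9, 8, 7, 4, 5)

# Identical to np.matmul, which loops over the batch axes implicitly.
assert np.allclose(c, np.matmul(a, b))
```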
However, if the channel is the last dimension (the default data format), K.batch_dot will multiply the wrong axes: the channel axis gets treated as a spatial (height/width) axis, and the height-by-width matrix multiplication fails. Thus, we need K.permute_dimensions as a preprocessing step for batch matrix multiplication.
Batch Matrix Multiplication: K.permute_dimensions and K.batch_dot
Take multi-channel data as an example.
from keras import backend as K

# batch a: (10, 512, 256, 3) -- 10 as batch size, 512x256 as height x width, 3 channels
# batch b: (10, 256, 512, 3) -- 10 as batch size, 256x512 as height x width, 3 channels
a = K.ones((10, 512, 256, 3))
b = K.ones((10, 256, 512, 3))

# Move the channel axis next to the batch axis (channels-first layout)
a_t = K.permute_dimensions(a, (0, 3, 1, 2))  # K.int_shape(a_t) = (10, 3, 512, 256)
b_t = K.permute_dimensions(b, (0, 3, 1, 2))  # K.int_shape(b_t) = (10, 3, 256, 512)
c_t = K.batch_dot(a_t, b_t, axes=(3, 2))     # contract axis 3 of a_t with axis 2 of b_t
K.int_shape(c_t)  # (10, 3, 512, 512)

# Move the channel axis back to the end (channels-last layout)
c = K.permute_dimensions(c_t, (0, 2, 3, 1))  # K.int_shape(c) = (10, 512, 512, 3)
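Alternatively, the two permutes can be avoided by writing the per-channel contraction directly as an einsum equation. The sketch below uses np.einsum so it is runnable stand-alone; tf.einsum accepts the same subscript string, and the shapes here are small stand-ins for the (10, 512, 256, 3) / (10, 256, 512, 3) tensors above.

```python
import numpy as np

# Channels-last batch matmul in one einsum call, contracting the shared
# width axis per batch element and per channel:
#   c[b, h, k, ch] = sum over w of a[b, h, w, ch] * b2[b, w, k, ch]
a = np.random.rand(2, 5, 4, 3)   # (batch, H, W, C)
b2 = np.random.rand(2, 4, 6, 3)  # (batch, W, K, C)
c = np.einsum('bhwc,bwkc->bhkc', a, b2)
print(c.shape)  # (2, 5, 6, 3)
```

Whether the einsum route or the permute-and-batch_dot route is faster will depend on the backend, so this is a readability alternative rather than a guaranteed optimization.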
The case remains open for a better answer; please enlighten me if you know more.