

mxnet.ndarray.linalg.gemm(A=None, B=None, C=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, beta=_Null, axis=_Null, out=None, name=None, **kwargs)

Performs general matrix multiplication and accumulation. Inputs are tensors A, B, C, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B) + beta * C

Here, alpha and beta are scalar parameters, and op() is either the identity or matrix transposition (depending on transpose_a, transpose_b).
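The n=2 formula can be checked with a small NumPy reference model. This is a sketch of the documented semantics only, not MXNet's implementation; the helper name gemm_2d is hypothetical. With all-ones inputs and transpose_b=True, every entry is 2.0 * 2 + 10.0 * 1 = 14.0, matching the single matrix example in this entry.

```python
import numpy as np

def gemm_2d(A, B, C, transpose_a=False, transpose_b=False, alpha=1.0, beta=1.0):
    """Hypothetical NumPy reference for out = alpha * op(A) * op(B) + beta * C."""
    opA = A.T if transpose_a else A   # op() is the identity or a transposition
    opB = B.T if transpose_b else B
    return alpha * (opA @ opB) + beta * C

A = np.ones((2, 2))
B = np.ones((3, 2))   # op(B) = B.T has shape (2, 3)
C = np.ones((2, 3))
out = gemm_2d(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
print(out)  # a 2x3 array filled with 14.0
```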

If n>2, gemm is performed separately for each matrix in a batch. The column indices of the matrices are given by the last dimension of the tensors, and the row indices by the axis specified with the axis parameter. By default, the trailing two dimensions are used for matrix encoding.
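For the default axis, the batched case behaves like independent 2-D gemm calls over the leading dimensions. A minimal NumPy sketch of that behavior follows (the helper gemm_batched is hypothetical). Note that np.matmul also broadcasts mismatched batch dimensions, whereas this operator expects the leading n-2 dimensions to have the same shape, so keep them equal when comparing.

```python
import numpy as np

def gemm_batched(A, B, C, transpose_a=False, transpose_b=False, alpha=1.0, beta=1.0):
    # Rows on axis -2, columns on axis -1 (the default axis=-2);
    # every leading dimension is treated as a batch dimension.
    opA = np.swapaxes(A, -1, -2) if transpose_a else A
    opB = np.swapaxes(B, -1, -2) if transpose_b else B
    # np.matmul multiplies matching batch entries independently.
    return alpha * np.matmul(opA, opB) + beta * C

A = np.array([[[1.0, 1.0]], [[0.1, 0.1]]])   # shape (2, 1, 2): two 1x2 matrices
B = np.array([[[1.0, 1.0]], [[0.1, 0.1]]])   # shape (2, 1, 2)
C = np.array([[[10.0]], [[0.01]]])           # shape (2, 1, 1)
out = gemm_batched(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
print(out.ravel())  # approximately [104.0, 0.14]
```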

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes calls. For example, let A, B, C be 5-dimensional tensors. Then gemm(A, B, C, axis=1) is equivalent to

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = swapaxes(C, dim1=1, dim2=3)
C = gemm(A1, B1, C)
C = swapaxes(C, dim1=1, dim2=3)

without the overhead of the additional swapaxes operations.
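The swapaxes equivalence can be sanity-checked in NumPy. This sketch (helper names are hypothetical) moves the row axis to position n-2, applies a default-layout gemm, and moves it back, then compares the result with a direct einsum contraction over A's last axis and B's row axis (axis 1). It models the documented semantics only and says nothing about how MXNet avoids the extra copies.

```python
import numpy as np

def gemm_ref(A, B, C, alpha=1.0, beta=1.0):
    # Default layout: rows on axis -2, columns on axis -1.
    return alpha * np.matmul(A, B) + beta * C

def gemm_ref_axis(A, B, C, axis, alpha=1.0, beta=1.0):
    # Hypothetical helper mirroring the swapaxes/gemm/swapaxes recipe.
    last2 = A.ndim - 2
    A1 = np.swapaxes(A, axis, last2)
    B1 = np.swapaxes(B, axis, last2)
    C1 = np.swapaxes(C, axis, last2)
    return np.swapaxes(gemm_ref(A1, B1, C1, alpha, beta), axis, last2)

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5, 6))   # rows (size 3) on axis 1, columns (size 6) last
B = rng.standard_normal((2, 6, 4, 5, 7))   # rows (size 6) on axis 1, columns (size 7) last
C = rng.standard_normal((2, 3, 4, 5, 7))

via_swap = gemm_ref_axis(A, B, C, axis=1, alpha=2.0, beta=0.5)
# Direct contraction: sum over A's column index p and B's row index p.
direct = 2.0 * np.einsum("aicdp,apcdj->aicdj", A, B) + 0.5 * C
print(np.allclose(via_swap, direct))  # True
```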

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are both set to 1, this operator will try to use pseudo-float16 precision (float32 math with float16 I/O) in order to use Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
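One way to set both variables from Python is sketched below. It assumes the MXNet runtime reads them from the process environment, and setting them before importing mxnet is the conservative choice because a variable may be read only once. Exporting them in the shell (for example, export MXNET_CUDA_ALLOW_TENSOR_CORE=1) before starting Python should work equally well. The commented-out lines are placeholders for your own GPU code.

```python
import os

# Both variables must equal "1"; set them before mxnet is imported
# (assumption: the runtime may read them only once).
os.environ["MXNET_CUDA_ALLOW_TENSOR_CORE"] = "1"
os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"

# import mxnet as mx
# ... float32 gemm calls on a GPU context may now use Tensor Cores
```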


The operator supports float32 and float64 data types only.


// Single matrix multiply-add
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
        = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]

// Batch matrix multiply-add
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
C = [[[10.0]], [[0.01]]]
gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
        = [[[104.0]], [[0.14]]]

Defined in src/operator/tensor/

Parameters

  • A (NDArray) – Tensor of input matrices

  • B (NDArray) – Tensor of input matrices

  • C (NDArray) – Tensor of input matrices

  • transpose_a (boolean, optional, default=0) – Multiply by the transpose of the first input (A).

  • transpose_b (boolean, optional, default=0) – Multiply by the transpose of the second input (B).

  • alpha (double, optional, default=1) – Scalar factor that multiplies op(A)*op(B).

  • beta (double, optional, default=1) – Scalar factor that multiplies C.

  • axis (int, optional, default='-2') – Axis corresponding to the matrix rows.

  • out (NDArray, optional) – The output NDArray to hold the result.


Returns

out – The output of this function.

Return type

NDArray or list of NDArrays