
gemm2

mxnet.ndarray.linalg.gemm2(A=None, B=None, transpose_a=_Null, transpose_b=_Null, alpha=_Null, axis=_Null, out=None, name=None, **kwargs)

Performs general matrix multiplication. The inputs are tensors A and B, each of dimension n >= 2 and having the same shape on the leading n-2 dimensions.

If n=2, the BLAS3 function gemm is performed:

out = alpha * op(A) * op(B)

Here alpha is a scalar parameter and op() is either the identity or the matrix transpose, selected by transpose_a for A and transpose_b for B.
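To make the n=2 formula concrete, here is a small NumPy reference model of out = alpha * op(A) * op(B). It is a sketch of the documented semantics only, not MXNet's implementation; the helper name `gemm2_2d_ref` is hypothetical, and it computes in float64 for simplicity.

```python
import numpy as np


def gemm2_2d_ref(A, B, transpose_a=False, transpose_b=False, alpha=1.0):
    """NumPy reference for the n=2 case: out = alpha * op(A) * op(B).

    Hypothetical helper for illustration; it mimics the documented
    semantics of gemm2 and is not part of MXNet.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("this sketch only handles 2-D inputs")
    # op() is the identity or the transpose, selected by the flags.
    op_a = A.T if transpose_a else A
    op_b = B.T if transpose_b else B
    if op_a.shape[1] != op_b.shape[0]:
        raise ValueError(f"inner dimensions differ: {op_a.shape} vs {op_b.shape}")
    return alpha * (op_a @ op_b)


# The single-matrix example from this page: A is 2x2, B is 3x2, op(B) = B^T.
A = np.ones((2, 2))
B = np.ones((3, 2))
out = gemm2_2d_ref(A, B, transpose_b=True, alpha=2.0)
print(out)  # a 2x3 matrix filled with 4.0
```

Note that op() is applied before the shapes are matched: B is 3x2, so only with transpose_b=True do the inner dimensions (2 and 2) agree.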

If n > 2, gemm is performed separately on each matrix in the batch. The column indices of the matrices are given by the last dimension of the tensors, and the row indices by the dimension specified with the axis parameter. By default (axis=-2), the trailing two dimensions encode the matrices and the leading n-2 dimensions index the batch.
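The batched case with the default axis can be sketched in NumPy as an explicit loop that runs one independent 2-D gemm per index over the leading n-2 dimensions. This is an illustrative model under the assumption that the leading dimensions must match exactly (as stated above, so no broadcasting); the helper name `gemm2_batched_ref` is hypothetical and not part of MXNet.

```python
import numpy as np


def gemm2_batched_ref(A, B, transpose_a=False, transpose_b=False, alpha=1.0):
    """Loop-based NumPy reference for n > 2 with the default axis (-2).

    Hypothetical helper: each matrix is A[idx] / B[idx] for an index idx over
    the leading n-2 dimensions, which must match exactly (no broadcasting).
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    if A.ndim < 2 or A.ndim != B.ndim or A.shape[:-2] != B.shape[:-2]:
        raise ValueError("A and B need the same rank >= 2 and equal leading dims")
    # op() acts on the trailing two dimensions of every matrix in the batch.
    op_a = np.swapaxes(A, -1, -2) if transpose_a else A
    op_b = np.swapaxes(B, -1, -2) if transpose_b else B
    out = np.empty(A.shape[:-2] + (op_a.shape[-2], op_b.shape[-1]))
    # One independent 2-D gemm per index over the leading dimensions.
    for idx in np.ndindex(A.shape[:-2]):
        out[idx] = alpha * (op_a[idx] @ op_b[idx])
    return out


# The batch example from this page: two 1x2 matrices per input.
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
out = gemm2_batched_ref(A, B, transpose_b=True, alpha=2.0)
print(out.shape)  # (2, 1, 1)
```

The loop makes the "separately for a batch" semantics explicit; for equal leading shapes it gives the same result as a single `np.matmul` call on the stacked arrays.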

For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm2/swapaxes calls. For example, let A and B be 5-dimensional tensors. Then gemm2(A, B, axis=1) is equivalent to

A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxes(C, dim1=1, dim2=3)

without the overhead of the additional swapaxes operations.
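The swapaxes/gemm2/swapaxes equivalence above can be checked with a NumPy sketch that moves the row axis into position n-2, multiplies over the trailing two dimensions, and moves it back. The result is compared against an explicit `np.einsum` contraction over the column index of A and the row index of B. The helper `gemm2_axis_ref` is hypothetical, omits the transpose flags for brevity, and is not MXNet's implementation.

```python
import numpy as np


def gemm2_axis_ref(A, B, axis=-2, alpha=1.0):
    """NumPy sketch of gemm2's `axis` semantics via swapaxes/matmul/swapaxes.

    Hypothetical helper (no transpose flags). Row indices live on `axis`,
    column indices on the last dimension; all other dimensions are batch
    dimensions and must match between A and B.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    n = A.ndim
    if B.ndim != n or not -n <= axis < n:
        raise ValueError("need equal ranks and an in-range axis")
    row = axis % n  # normalize negative axes, e.g. -2 -> n-2
    if row == n - 1:
        raise ValueError("the row axis cannot be the last (column) axis")
    # Move the row axis into the default position n-2 ...
    A1 = np.swapaxes(A, row, n - 2)
    B1 = np.swapaxes(B, row, n - 2)
    if A1.shape[:-2] != B1.shape[:-2]:
        raise ValueError("batch dimensions of A and B differ")
    # ... run the default batched gemm over the trailing two dimensions ...
    C = alpha * np.matmul(A1, B1)
    # ... and move the row axis back to where it came from.
    return np.swapaxes(C, row, n - 2)


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5, 6))  # rows on axis 1 (3), columns on axis 4 (6)
B = rng.standard_normal((2, 6, 4, 5, 7))  # rows on axis 1 (6), columns on axis 4 (7)
C = gemm2_axis_ref(A, B, axis=1)
# C[x, i, y, z, j] = sum_k A[x, i, y, z, k] * B[x, k, y, z, j]
expected = np.einsum("xiyzk,xkyzj->xiyzj", A, B)
print(C.shape)                   # (2, 3, 4, 5, 7)
print(np.allclose(C, expected))  # True
```

With axis=1 on 5-D inputs, dimensions 0, 2 and 3 act as batch dimensions, which is why both inputs must agree on them.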

Note

The operator supports float32 and float64 data types only.

Examples:

// Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b=True, alpha=2.0)
         = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]

// Batch matrix multiply
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
gemm2(A, B, transpose_b=True, alpha=2.0)
         = [[[4.0]], [[0.04]]]

Defined in src/operator/tensor/la_op.cc:L151

Parameters:
  • A (NDArray) – Tensor of input matrices
  • B (NDArray) – Tensor of input matrices
  • transpose_a (boolean, optional, default=0) – Multiply with the transpose of the first input (A).
  • transpose_b (boolean, optional, default=0) – Multiply with the transpose of the second input (B).
  • alpha (double, optional, default=1) – Scalar factor multiplied with op(A)*op(B).
  • axis (int, optional, default=-2) – Axis corresponding to the matrix row indices.
  • out (NDArray, optional) – The output NDArray to hold the result.
Returns:

out – The output of this function.

Return type:

NDArray or list of NDArrays