
check_symbolic_forward

mxnet.test_utils.check_symbolic_forward(sym, location, expected, rtol=0.0001, atol=None, aux_states=None, ctx=None, equal_nan=False, dtype=numpy.float32)

Compares a symbol’s forward results with the expected ones. Raises an assertion error describing the mismatch if any forward output differs from its expected value by more than the given tolerances.
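
The comparison is tolerance-based rather than exact. The following is a minimal NumPy sketch of that idea, assuming an np.allclose-style rule where each element must satisfy |expected - actual| <= atol + rtol * |actual|. The helper outputs_match is hypothetical and only illustrates the check; it is not MXNet's own implementation, and its atol default of 0.0 is an assumption.

```python
import numpy as np

def outputs_match(expected, actual, rtol=1e-4, atol=0.0, equal_nan=False):
    """Hypothetical helper: True if every (expected, actual) pair agrees within tolerance.

    Assumes an np.allclose-style elementwise rule:
        |expected - actual| <= atol + rtol * |actual|
    atol defaults to 0.0 here purely as an assumption for the sketch.
    """
    if len(expected) != len(actual):
        return False
    return all(
        np.allclose(e, a, rtol=rtol, atol=atol, equal_nan=equal_nan)
        for e, a in zip(expected, actual)
    )

expected = [np.array([[19.0, 22.0], [43.0, 50.0]])]
close = [np.array([[19.001, 22.0], [43.0, 50.0]], dtype=np.float32)]
far = [np.array([[19.1, 22.0], [43.0, 50.0]])]

print(outputs_match(expected, close))  # error of 0.001 is within rtol * 19.001
print(outputs_match(expected, far))    # error of 0.1 exceeds the relative tolerance
```

With equal_nan=True, a NaN in the output passes only where the expected array also holds a NaN at the same position.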

Parameters:
  • sym (Symbol) – output symbol
  • location (list of np.ndarray or dict of str to np.ndarray) –

    The evaluation point

    • if type is list of np.ndarray
      Contains all the NumPy arrays corresponding to sym.list_arguments(), in the same order.
    • if type is dict of str to np.ndarray
      Contains the mapping between argument names and their values.
  • expected (list of np.ndarray or dict of str to np.ndarray) –

    The expected output value

    • if type is list of np.ndarray
      Contains arrays corresponding to the executor’s outputs (exe.outputs), in the same order.
    • if type is dict of str to np.ndarray
      Contains the mapping from the names in sym.list_outputs() to the expected output arrays.
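  • rtol (float, optional) – Relative tolerance used when comparing each output with its expected value. Defaults to 1e-4.
  • atol (float, optional) – Absolute tolerance used in the same comparison. If None, the comparison function’s default is used.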
  • aux_states (list of np.ndarray or dict, optional) –
    • if type is list of np.ndarray
      Contains all the NumPy arrays corresponding to sym.list_auxiliary_states(), in the same order.
    • if type is dict of str to np.ndarray
      Contains the mapping between names of auxiliary states and their values.
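  • ctx (Context, optional) – Device context to run the forward pass on. If None, the current default context is used.
  • dtype (np.float16 or np.float32 or np.float64) – Data type of the mx.nd.array objects created from location and aux_states.
  • equal_nan (bool) – If True, NaN values are treated as equal to each other (i.e. nan == nan) when checking equivalence.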
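
When location or expected is a dict, values are matched by name, so the dict’s own key order does not matter; a list is taken positionally. The sketch below mirrors that reordering in plain Python. The helper order_by_names is hypothetical, and the list arg_names simply stands in for what sym.list_arguments() would return for the dot example in the Example section (assumed to be ['lhs', 'rhs']).

```python
import numpy as np

def order_by_names(values, names):
    """Hypothetical helper: return `values` as a list aligned with `names`.

    A list is assumed to already follow the order of `names`.
    A dict must contain exactly the keys in `names` and is reordered to match.
    """
    if isinstance(values, dict):
        if set(values) != set(names):
            raise ValueError("keys %s do not match names %s" % (sorted(values), names))
        return [values[n] for n in names]
    return list(values)

arg_names = ["lhs", "rhs"]  # stands in for sym.list_arguments()
mat1 = np.array([[1, 2], [3, 4]])
mat2 = np.array([[5, 6], [7, 8]])

as_list = order_by_names([mat1, mat2], arg_names)
as_dict = order_by_names({"rhs": mat2, "lhs": mat1}, arg_names)  # key order is irrelevant

print(all(np.array_equal(a, b) for a, b in zip(as_list, as_dict)))
```

The same name-based lookup applies to a dict passed as expected, using the names from sym.list_outputs().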

Example

>>> import mxnet as mx
>>> import numpy as np
>>> from mxnet.test_utils import check_symbolic_forward
>>> lhs = mx.symbol.Variable('lhs')
>>> rhs = mx.symbol.Variable('rhs')
>>> sym_dot = mx.symbol.dot(lhs, rhs)
>>> mat1 = np.array([[1, 2], [3, 4]])
>>> mat2 = np.array([[5, 6], [7, 8]])
>>> ret_expected = np.array([[19, 22], [43, 50]])
>>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
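
Expected outputs are often easier to trust when computed independently with NumPy rather than typed by hand. A short sketch for the dot example above, deriving ret_expected with np.dot and building the dict form of location keyed by the Variable names 'lhs' and 'rhs'. Only NumPy is run here; the MXNet call is left as a comment.

```python
import numpy as np

mat1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
mat2 = np.array([[5, 6], [7, 8]], dtype=np.float32)

# Reference result computed with NumPy instead of by hand.
ret_expected = np.dot(mat1, mat2)
print(ret_expected)

# Dict form of the same evaluation point, keyed by the Variable names.
location = {"lhs": mat1, "rhs": mat2}

# With MXNet installed, this dict could be passed in place of the list:
#   check_symbolic_forward(sym_dot, location, [ret_expected])
```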