
cond

mxnet.ndarray.contrib.cond(pred, then_func, else_func)

Run an if-then-else using a user-defined condition and computation.

This operator simulates an if-like branch that performs one of two user-defined computations, chosen according to the specified condition.

pred is a scalar MXNet NDArray, indicating which branch of computation should be used.

then_func is a user-defined function used as the computation of the then branch. It takes no arguments and produces outputs, which are an NDArray or a nested list of NDArrays. Its signature should be then_func() => NDArray or nested List[NDArray].

else_func is a user-defined function used as the computation of the else branch. It takes no arguments and produces outputs, which are an NDArray or a nested list of NDArrays. Its signature should be else_func() => NDArray or nested List[NDArray].
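Because both branch functions take no arguments, any inputs they need are captured from the enclosing scope. The following is a minimal pure-Python model of the imperative behaviour, not MXNet's implementation: it assumes pred reduces to a single Python bool, uses a one-element list as a stand-in for a size-1 NDArray, and the name cond_model is hypothetical. It shows that only the selected branch is ever executed.

```python
from typing import Any, Callable


def cond_model(pred: Any, then_func: Callable[[], Any], else_func: Callable[[], Any]) -> Any:
    """Hypothetical stand-in for cond: evaluate exactly one zero-argument branch."""
    # A one-element list plays the role of a size-1 (scalar) NDArray.
    if isinstance(pred, (list, tuple)):
        if len(pred) != 1:
            raise ValueError("pred must contain exactly one element")
        pred = pred[0]
    # Only the chosen branch is called; the other one is never evaluated.
    return then_func() if bool(pred) else else_func()


a, b = 1.0, 2.0
calls = []  # records which branch actually ran


def then_branch():
    calls.append("then")
    return [(a + 5) * (b + 5)]  # inputs a and b are captured by closure


def else_branch():
    calls.append("else")
    return [(a - 5) * (b - 5)]


result = cond_model([a * b < 5], then_branch, else_branch)
print(result, calls)  # [42.0] ['then']
```

Capturing inputs by closure, rather than passing them as parameters, is what keeps the zero-argument signature workable.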

The outputs produced by then_func and else_func must contain the same number of elements, and corresponding elements must have the same shape, dtype, and stype.
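The compatibility requirement above can be illustrated with a small checker. This is a sketch under stated assumptions, not part of MXNet: ArraySpec, flatten, and branches_compatible are hypothetical names, ArraySpec stands in for an NDArray's (shape, dtype, stype) metadata, and the check compares the two branches' outputs after flattening any nesting, element by element.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical metadata record standing in for an NDArray's shape, dtype and stype."""
    shape: tuple
    dtype: str
    stype: str = "default"


def flatten(outputs: Any) -> List[ArraySpec]:
    """Flatten a single spec or an arbitrarily nested list of specs, left to right."""
    if isinstance(outputs, ArraySpec):
        return [outputs]
    flat: List[ArraySpec] = []
    for item in outputs:
        flat.extend(flatten(item))
    return flat


def branches_compatible(then_out: Any, else_out: Any) -> bool:
    """True if both branches yield equally many arrays with matching shape, dtype and stype."""
    left, right = flatten(then_out), flatten(else_out)
    return len(left) == len(right) and all(x == y for x, y in zip(left, right))


f32 = ArraySpec((1,), "float32")
print(branches_compatible([f32, [f32]], [f32, [f32]]))          # True
print(branches_compatible([f32], [f32, f32]))                    # False: different counts
print(branches_compatible([f32], [ArraySpec((1,), "float64")]))  # False: dtype differs
```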

This function returns an NDArray or a nested list of NDArrays representing the computation result.

Parameters:
  • pred (an MXNet NDArray representing a scalar) – The branch condition.
  • then_func (a Python function) – The computation to be executed if pred is true.
  • else_func (a Python function) – The computation to be executed if pred is false.
Returns:

outputs

Return type:

an NDArray or nested lists of NDArrays, representing the result of computation.

Examples

>>> import mxnet as mx
>>> a, b = mx.nd.array([1]), mx.nd.array([2])
>>> pred = a * b < 5
>>> # Branch functions take no arguments; a and b are captured from the enclosing scope.
>>> then_func = lambda: (a + 5) * (b + 5)
>>> else_func = lambda: (a - 5) * (b - 5)
>>> outputs = mx.nd.contrib.cond(pred, then_func, else_func)
>>> outputs[0]
[42.]
<NDArray 1 @cpu(0)>