
while_loop

mxnet.ndarray.contrib.while_loop(cond, func, loop_vars, max_iterations=None)

Run a while loop with user-defined computation and loop condition.

This operator simulates a while loop that iteratively performs user-defined computation for as long as the loop condition is satisfied.

loop_vars is a list of NDArrays used by the computation.

cond is a user-defined function, used as the loop condition. It consumes loop_vars and produces a scalar MXNet NDArray that decides whether the loop continues: the loop ends when cond returns false (zero). cond is variadic, and its signature should be cond(*loop_vars) => NDArray.
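The following plain-Python sketch shows how a variadic cond is evaluated on the current loop variables and reduced to a continue/stop decision. It is an illustration, not MXNet code: the helper should_continue is hypothetical, and a one-element Python list stands in for the scalar NDArray that cond returns.

```python
def should_continue(cond, loop_vars):
    """Evaluate the loop condition on the current loop variables.

    Stand-in for the NDArray case: cond returns a one-element list in
    place of a scalar NDArray; any nonzero value continues the loop.
    """
    result = cond(*loop_vars)  # cond is variadic: one argument per loop var
    if len(result) != 1:
        raise ValueError("cond must produce a single scalar")
    return result[0] != 0  # zero (false) terminates the loop
```

For example, with cond = lambda i, s: [int(i <= 5)], the loop continues for loop variables [5, 11] and stops for [6, 16].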

func is a user-defined function, used as the loop body. It also consumes loop_vars, and produces step_output and new_loop_vars at each step. step_output must contain the same number of elements in every step, and across all steps the i-th element of step_output must keep the same shape and dtype. Likewise, new_loop_vars must contain the same number of elements as loop_vars, and each element must have the same shape and dtype as its counterpart in loop_vars. func is variadic, and its signature should be func(*loop_vars) => (NDArray or nested List[NDArray] step_output, NDArray or nested List[NDArray] new_loop_vars).
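A minimal plain-Python sketch of the shape and dtype contract that func must honor. The helpers signature and check_step are hypothetical and are not part of MXNet; each "NDArray" is modelled as a 1-D Python list whose length plays the role of the shape and whose element type plays the role of the dtype.

```python
def signature(x):
    """Stand-in for (shape, dtype): the length of a 1-D list and the type
    of its first element."""
    return (len(x), type(x[0]) if x else None)


def check_step(loop_vars, step_output, new_loop_vars, previous_output=None):
    """Raise ValueError if one body result breaks the while_loop contract.

    previous_output is the step_output of the preceding step, if any.
    """
    # new_loop_vars must mirror loop_vars element by element.
    if len(new_loop_vars) != len(loop_vars):
        raise ValueError("new_loop_vars must have one entry per loop variable")
    for old, new in zip(loop_vars, new_loop_vars):
        if signature(old) != signature(new):
            raise ValueError("a loop variable changed shape or dtype")
    # step_output must keep the same length and per-element signature.
    if previous_output is not None:
        if len(step_output) != len(previous_output):
            raise ValueError("every step must emit the same number of outputs")
        for prev, cur in zip(previous_output, step_output):
            if signature(prev) != signature(cur):
                raise ValueError("a step output changed shape or dtype")
```

A body that turns an integer loop variable into a float, drops a loop variable, or emits a different number of outputs than in the previous step is rejected by this check.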

max_iterations is a scalar that defines the maximum number of iterations allowed.

This function returns two lists. The first list has length |step_output|; its i-th element holds the i-th element of step_output from every step, stacked along axis 0. The second list has length |loop_vars| and holds the final states of the loop variables.
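To make the return layout concrete, here is a plain-Python sketch of the loop and of how per-step outputs are regrouped by position. The function reference_while_loop is hypothetical: it uses Python values instead of NDArrays and omits padding axis 0 up to max_iterations.

```python
def reference_while_loop(cond, func, loop_vars, max_iterations):
    """Plain-Python sketch of the eager loop and of how outputs are stacked.

    Values are plain Python objects rather than NDArrays, and axis 0 is
    not padded up to max_iterations.
    """
    loop_vars = list(loop_vars)
    per_step = []  # per_step[t] is the step_output of iteration t
    while len(per_step) < max_iterations and cond(*loop_vars):
        step_output, loop_vars = func(*loop_vars)
        per_step.append(list(step_output))
        loop_vars = list(loop_vars)
    # Regroup so that outputs[i] lists the i-th element of every step's
    # step_output in step order, i.e. "stacked along axis 0".
    outputs = [list(column) for column in zip(*per_step)]
    return outputs, loop_vars
```

Applied to the cond and func from the example at the end of this section, with plain integers (0, 1) as the initial loop variables and max_iterations=10, this sketch returns [[1, 2, 4, 7, 11, 16]] and [6, 16], matching the defined rows of outputs and the final states shown there.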

Warning

For now, axis 0 of every NDArray in the first list has length max_iterations, due to the lack of dynamic shape inference. Rows beyond the last executed step hold undefined values.
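The padding can be pictured with the hypothetical helper pad_axis0 below, a plain-Python sketch in which None marks the undefined rows. The actual contents of those rows in the returned NDArray are unspecified and should not be read.

```python
def pad_axis0(stacked, max_iterations, undefined=None):
    """Pad a stacked output (one row per executed step) along axis 0 so
    that it has exactly max_iterations rows.

    `undefined` is only a placeholder for the unspecified padding values.
    """
    n_steps = len(stacked)
    if n_steps > max_iterations:
        raise ValueError("cannot have run more than max_iterations steps")
    return list(stacked) + [undefined] * (max_iterations - n_steps)


# The six defined rows from the example below, padded to 10 rows.
padded = pad_axis0([[1], [2], [4], [7], [11], [16]], max_iterations=10)
valid_rows = padded[:6]  # only the first n_steps rows are meaningful
```

Because only the leading rows are defined, it helps to carry a step counter as a loop variable. In the example below, i starts at 0 and is incremented once per step, so its final value (6) is exactly the number of valid rows.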

Warning

If cond is false on the very first check, step_output is assumed to be empty, because its structure cannot be inferred. This differs from the symbolic version.
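A plain-Python illustration of this case, with hypothetical initial values chosen so that the condition fails immediately. Because func never runs, there is no step_output from which to learn the number or shapes of the outputs, so the regrouped result is simply empty.

```python
cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
loop_vars = [7, 1]  # hypothetical start values: i already exceeds 5
max_iterations = 10

recorded = []  # step_output of each executed step
while len(recorded) < max_iterations and cond(*loop_vars):
    step_output, loop_vars = func(*loop_vars)
    recorded.append(step_output)

# func never ran: regrouping by output index yields an empty list,
# and the loop variables come back unchanged.
outputs = [list(column) for column in zip(*recorded)]
states = loop_vars
```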

Parameters:
  • cond (a Python function) – The loop condition.
  • func (a Python function) – The loop body.
  • loop_vars (an NDArray or nested lists of NDArrays) – The initial values of the loop variables.
  • max_iterations (a Python int) – Maximum number of iterations.
Returns:

  • outputs (an NDArray or nested lists of NDArrays) – stacked output from each step
  • states (an NDArray or nested lists of NDArrays) – final state

Examples

>>> import mxnet as mx
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
>>> outputs
[
[[ 1]
[ 2]
[ 4]
[ 7]
[11]
[16]
[...]  # undefined value
[...]
[...]
[...]]
<NDArray 10x1 @cpu(0)>]
>>> states
[
[6]
<NDArray 1 @cpu(0)>,
[16]
<NDArray 1 @cpu(0)>]