
mxnet.io.NDArrayIter

class mxnet.io.NDArrayIter(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')

Returns an iterator for mx.nd.NDArray, numpy.ndarray, h5py.Dataset, mx.nd.sparse.CSRNDArray or scipy.sparse.csr_matrix.

Examples

>>> import mxnet as mx
>>> import numpy as np
>>> data = np.arange(40).reshape((10,2,2))
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> for batch in dataiter:
...     print(batch.data[0].asnumpy())
...     batch.data[0].shape
...
[[[ 36.  37.]
  [ 38.  39.]]
 [[ 16.  17.]
  [ 18.  19.]]
 [[ 12.  13.]
  [ 14.  15.]]]
(3L, 2L, 2L)
[[[ 32.  33.]
  [ 34.  35.]]
 [[  4.   5.]
  [  6.   7.]]
 [[ 24.  25.]
  [ 26.  27.]]]
(3L, 2L, 2L)
[[[  8.   9.]
  [ 10.  11.]]
 [[ 20.  21.]
  [ 22.  23.]]
 [[ 28.  29.]
  [ 30.  31.]]]
(3L, 2L, 2L)
>>> dataiter.provide_data # Returns a list of `DataDesc`
[DataDesc[data,(3, 2L, 2L),<type 'numpy.float32'>,NCHW]]
>>> dataiter.provide_label # Returns a list of `DataDesc`
[DataDesc[softmax_label,(3, 1L),<type 'numpy.float32'>,NCHW]]

In the example above, the data is shuffled because shuffle is set to True, and the one leftover example (10 examples = 3 full batches of 3, plus 1) is discarded because last_batch_handle is set to 'discard'.
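To make the shuffle-then-discard behaviour concrete, here is a plain-Python sketch that works on example indices rather than arrays. The helper `shuffled_discard_batches` is hypothetical and is not part of MXNet; it only assumes that shuffling permutes the example order and that 'discard' keeps complete batches.

```python
import random

def shuffled_discard_batches(num_examples, batch_size, seed=0):
    """Hypothetical helper: shuffle example indices, then keep only full
    batches and drop the incomplete remainder ('discard' policy)."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # seeded so the result is reproducible
    num_full = num_examples // batch_size  # only complete batches are kept
    return [indices[i * batch_size:(i + 1) * batch_size] for i in range(num_full)]

batches = shuffled_discard_batches(10, 3)
print(len(batches))  # 3 full batches
dropped = set(range(10)) - {i for b in batches for i in b}
print(len(dropped))  # 1 example is left out of this pass
```

Which example gets dropped depends on the shuffle, which is why the output in the example above covers only 9 of the 10 inputs.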

Usage of the last_batch_handle parameter:

>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='pad')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx  # The last batch is padded with examples from the beginning, so ceil(10/3) = 4 batches are created.
4
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # The remaining example is discarded, so floor(10/3) = 3 batches are created.
3
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # Remaining examples are rolled over to the next iteration.
3
>>> dataiter.reset()
>>> print(dataiter.next().data[0].asnumpy())
[[[ 36.  37.]
  [ 38.  39.]]
 [[  0.   1.]
  [  2.   3.]]
 [[  4.   5.]
  [  6.   7.]]]
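The three last_batch_handle policies can be compared with a small index-level sketch. The function `epoch_batches` is hypothetical and only models the behaviour described in this page (pad from the beginning, drop the remainder, or carry the remainder into the next pass); it is not MXNet's implementation.

```python
def epoch_batches(num_examples, batch_size, policy, carry=()):
    """Hypothetical helper: return (batches, leftover) for one pass.

    policy is 'pad', 'discard' or 'roll_over'; carry holds indices rolled
    over from the previous pass (only meaningful for 'roll_over').
    """
    order = list(carry) + list(range(num_examples))  # rolled-over examples come first
    full = len(order) // batch_size
    batches = [order[i * batch_size:(i + 1) * batch_size] for i in range(full)]
    rest = order[full * batch_size:]
    if not rest:
        return batches, []
    if policy == 'pad':
        # Fill the short batch with examples from the beginning of the data.
        pad = [i % num_examples for i in range(batch_size - len(rest))]
        batches.append(rest + pad)
        return batches, []
    if policy == 'discard':
        return batches, []
    if policy == 'roll_over':
        return batches, rest  # leftovers start the next pass
    raise ValueError("unknown policy: %r" % policy)

for policy in ('pad', 'discard', 'roll_over'):
    batches, _ = epoch_batches(10, 3, policy)
    print(policy, len(batches))  # pad 4, discard 3, roll_over 3

first, carry = epoch_batches(10, 3, 'roll_over')
second, _ = epoch_batches(10, 3, 'roll_over', carry)
print(carry, second[0])  # [9] [9, 0, 1]
```

The second roll_over pass starts with example 9 followed by examples 0 and 1, which corresponds to the first batch printed after reset() above.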

NDArrayIter also supports multiple inputs and labels.

>>> data = {'data1':np.zeros(shape=(10,2,2)), 'data2':np.zeros(shape=(20,2,2))}
>>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
>>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')
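With several named inputs, each batch has to take the same example positions from every array so that inputs and labels stay aligned. The sketch below shows that idea with plain lists; `named_batches` is hypothetical and, for simplicity, assumes every array has the same number of examples and uses the 'discard' policy.

```python
import random

def named_batches(arrays, batch_size, seed=0):
    """Hypothetical helper: yield dicts of batches, one entry per named array.

    Assumes every array in `arrays` has the same number of examples; one
    shuffled index order is shared so examples stay aligned across names.
    """
    num_examples = len(next(iter(arrays.values())))
    order = list(range(num_examples))
    random.Random(seed).shuffle(order)
    # Step through full batches only ('discard' policy).
    for start in range(0, num_examples - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        yield {name: [values[i] for i in idx] for name, values in arrays.items()}

data = {'data1': list(range(10)), 'data2': [10 * x for x in range(10)]}
for batch in named_batches(data, 3):
    # The same positions were picked from both arrays.
    print(batch['data1'], batch['data2'])
```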

NDArrayIter also supports mx.nd.sparse.CSRNDArray input, but only with last_batch_handle set to 'discard'.

>>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
>>> [batch.data[0] for batch in dataiter]
[
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>]
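Batching a CSR matrix amounts to slicing contiguous row ranges out of its indptr/indices/values arrays. The plain-Python sketch below illustrates that row slicing for the same values as np.arange(40).reshape((10, 4)); both helpers are hypothetical, this is not how MXNet implements it, and it does not explain why padding is unsupported for CSR input.

```python
def dense_to_csr(rows):
    """Hypothetical helper: build (indptr, indices, values) from a dense list of rows."""
    indptr, indices, values = [0], [], []
    for row in rows:
        for col, v in enumerate(row):
            if v != 0:  # CSR stores only non-zero entries
                indices.append(col)
                values.append(v)
        indptr.append(len(values))
    return indptr, indices, values

def csr_row_slice(indptr, indices, values, start, end):
    """Hypothetical helper: rows [start, end) of a CSR matrix, with indptr re-based to 0."""
    lo, hi = indptr[start], indptr[end]
    new_indptr = [p - lo for p in indptr[start:end + 1]]
    return new_indptr, indices[lo:hi], values[lo:hi]

# Same values as np.arange(40).reshape((10, 4)); row 0 contains one zero.
dense = [[4 * r + c for c in range(4)] for r in range(10)]
indptr, indices, values = dense_to_csr(dense)
batch = csr_row_slice(indptr, indices, values, 0, 3)
print(batch[0])  # [0, 3, 7, 11]
```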

Parameters
  • data (array or list of array or dict of string to array) – The input data.

  • label (array or list of array or dict of string to array, optional) – The input label.

  • batch_size (int) – Batch size of data.

  • shuffle (bool, optional) – Whether to shuffle the data. Only supported if no h5py.Dataset inputs are used.

  • last_batch_handle (str, optional) – How to handle the last batch. This parameter can be ‘pad’, ‘discard’ or ‘roll_over’. If ‘pad’, the last batch is padded with data starting from the beginning. If ‘discard’, the last batch is discarded. If ‘roll_over’, the remaining elements are rolled over to the next iteration; this mode is intended for training and can cause problems if used for prediction.

  • data_name (str, optional) – The data name.

  • label_name (str, optional) – The label name.
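The data_name, label_name and batch_size parameters determine what provide_data and provide_label report: the input's name and a shape whose first dimension is replaced by the batch size, as in the example output above. The hypothetical helper below reproduces only that (name, shape) pair and ignores the dtype and layout fields of the real DataDesc.

```python
def describe(name, shape, batch_size):
    """Hypothetical helper: (name, per-batch shape) in the spirit of provide_data."""
    # The leading (example) dimension becomes the batch size; the rest is unchanged.
    return (name, (batch_size,) + tuple(shape[1:]))

print(describe('data', (10, 2, 2), 3))        # ('data', (3, 2, 2))
print(describe('softmax_label', (10, 1), 3))  # ('softmax_label', (3, 1))
```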

__init__(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__(data[, label, batch_size, shuffle, …])

Initialize self.

getdata()

Get data.

getindex()

Get index of the current batch.

getlabel()

Get label.

getpad()

Get pad value of DataBatch.

hard_reset()

Ignores rolled-over data and resets the iterator to the start.

iter_next()

Increments the cursor by batch_size for the next batch and checks whether the current cursor exceeds the number of data points.

next()

Returns the next batch of data.

reset()

Resets the iterator to the beginning of the data.
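The cursor described under iter_next() and the pad count returned by getpad() can be modelled in a few lines. The class below is a hypothetical, simplified model for last_batch_handle='pad' only; starting the cursor one batch before the data is an assumption made so that the first iter_next() lands on index 0.

```python
class PadCursor:
    """Hypothetical, simplified model of the cursor logic behind
    iter_next(), getpad() and reset() for last_batch_handle='pad'."""

    def __init__(self, num_data, batch_size):
        self.num_data = num_data
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        # Start one batch before the data so the first iter_next() lands on 0.
        self.cursor = -self.batch_size

    def iter_next(self):
        # Advance by one batch; report whether the cursor is still inside the data.
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def getpad(self):
        # Number of padded examples in the current batch (0 unless it runs past the end).
        overshoot = self.cursor + self.batch_size - self.num_data
        return max(overshoot, 0)

it = PadCursor(num_data=10, batch_size=3)
pads = []
while it.iter_next():
    pads.append(it.getpad())
print(pads)  # [0, 0, 0, 2]
```

With 10 examples and a batch size of 3, the fourth batch holds one real example and two padded ones, which matches the 4 batches counted in the 'pad' example above.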

Attributes

provide_data

The name and shape of data provided by this iterator.

provide_label

The name and shape of label provided by this iterator.