mxnet.executor_manager.DataParallelExecutorManager

class mxnet.executor_manager.DataParallelExecutorManager(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)[source]

Helper class to manage multiple executors for data parallelism.

Parameters
  • symbol (Symbol) – Output symbol.

  • ctx (list of Context) – Devices to run on.

  • param_names (list of str) – Name of all trainable parameters of the network.

  • arg_names (list of str) – Name of all arguments of the network.

  • aux_names (list of str) – Name of all auxiliary states of the network.

  • train_data (DataIter) – Training data iterator.

  • work_load_list (list of float or int, optional) – The list of work load for different devices, in the same order as ctx.

  • logger (logging.Logger, optional) – When not specified, the default logger will be used.

  • sym_gen (function, optional) – A function that generates new Symbols depending on different input shapes. Used only for bucketing.
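To give an intuition for work_load_list: the batch is sliced across the devices in ctx in proportion to each device's load. The sketch below is illustrative only and is not the manager's internal implementation (which slices NDArray batches internally); the function name split_workload is hypothetical.

```python
def split_workload(batch_size, work_load_list):
    """Split a batch across devices proportionally to work_load_list.

    Illustrative sketch only; the real manager slices NDArray batches
    internally. Returns one slice object per device, in ctx order.
    """
    total = sum(work_load_list)
    slices = []
    start = 0
    for i, load in enumerate(work_load_list):
        if i == len(work_load_list) - 1:
            end = batch_size  # last device takes the remainder
        else:
            end = start + round(batch_size * load / total)
        slices.append(slice(start, end))
        start = end
    return slices

# e.g. a batch of 9 split across two devices with a 2:1 load ratio
print(split_workload(9, [2, 1]))  # [slice(0, 6, None), slice(6, 9, None)]
```

With work_load_list omitted (None), the load is treated as uniform, so each device receives an equal share of the batch.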

__init__(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)[source]

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__(symbol, ctx, train_data, arg_names, …)

Initialize self.

backward()

Run backward on the current executor.

copy_to(arg_params, aux_params)

Copy data from each executor to `arg_params` and `aux_params`.

forward([is_train])

Run forward on the current executor.

install_monitor(monitor)

Install monitor on all executors.

load_data_batch(data_batch)

Load data and labels into arrays.

set_params(arg_params, aux_params)

Set parameter and aux values.

update_metric(metric, labels[, pre_sliced])

Update metric with the current executor.
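The methods above are typically called in a fixed order for each training batch: load the batch, run forward with is_train=True, run backward, apply the optimizer update, then update the metric. The mock class below is a stand-in used only to illustrate that call sequence; it is not the real MXNet class, and train_one_epoch is a hypothetical helper.

```python
class MockExecutorManager:
    """Stand-in that records the order of manager method calls.

    Mimics the DataParallelExecutorManager interface for illustration;
    the real class dispatches to per-device executors.
    """
    def __init__(self):
        self.calls = []

    def load_data_batch(self, data_batch):
        self.calls.append("load_data_batch")

    def forward(self, is_train=False):
        self.calls.append("forward")

    def backward(self):
        self.calls.append("backward")

    def update_metric(self, metric, labels):
        self.calls.append("update_metric")


def train_one_epoch(mgr, data_iter, metric):
    # Typical per-batch sequence when training with the manager.
    for batch in data_iter:
        mgr.load_data_batch(batch)
        mgr.forward(is_train=True)
        mgr.backward()
        # (an optimizer step using mgr.param_arrays / mgr.grad_arrays
        # would go here, before metric evaluation)
        mgr.update_metric(metric, getattr(batch, "label", None))


mgr = MockExecutorManager()
train_one_epoch(mgr, [object()], metric=None)
print(mgr.calls)  # ['load_data_batch', 'forward', 'backward', 'update_metric']
```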

Attributes

aux_arrays

Shared aux states.

grad_arrays

Shared gradient arrays.

param_arrays

Shared parameter arrays.
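Since each device computes its own gradients into the shared grad_arrays, those per-device gradients must be combined before a parameter update. The sketch below shows the idea with plain Python lists; it is illustrative only, as MXNet performs this reduction with NDArray operations (typically through a KVStore) rather than Python loops, and reduce_gradients is a hypothetical name.

```python
def reduce_gradients(grad_arrays):
    """Average per-device gradients for a single parameter.

    grad_arrays: a list with one entry per device, each a list of
    floats holding that device's gradient for the parameter.
    Illustrative only; MXNet uses NDArray ops for this reduction.
    """
    n_dev = len(grad_arrays)
    n_elem = len(grad_arrays[0])
    return [sum(g[i] for g in grad_arrays) / n_dev for i in range(n_elem)]

# gradients for the same parameter computed on two devices
print(reduce_gradients([[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```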