

class mxnet.module.BaseModule(logger=logging)

The base class of a module.

A module represents a computation component. One can think of a module as a computation machine: it can execute forward and backward passes and update the parameters of a model. The APIs aim to be easy to use, especially when the imperative API is needed to work with multiple modules (e.g. a stochastic depth network).

A module has several states:

  • Initial state: Memory is not allocated yet, so the module is not ready for computation.
  • Bound (binded): Shapes for inputs, outputs, and parameters are all known, memory has been allocated, and the module is ready for computation.
  • Parameters are initialized: For modules with parameters, doing computation before initializing the parameters might result in undefined outputs.
  • Optimizer is installed: An optimizer can be installed on a module. Once it is installed, the parameters of the module can be updated according to the optimizer after gradients have been computed (forward-backward).
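
The progression through these states can be sketched with a toy class. This is a minimal illustration, not MXNet code: ToyModule and its state() helper are hypothetical, and the ordering checks in init_params and init_optimizer are assumptions made for the sketch. Only the flag names (binded, params_initialized, optimizer_initialized) come from this page.

```python
class ToyModule:
    """Hypothetical stand-in that only tracks state flags; not MXNet code."""

    def __init__(self):
        # Initial state: no memory allocated, nothing initialized.
        self.binded = False
        self.params_initialized = False
        self.optimizer_initialized = False

    def bind(self):
        # A real module would infer shapes and allocate memory here.
        self.binded = True

    def init_params(self):
        # Assumption: parameters can only be initialized once memory exists.
        if not self.binded:
            raise RuntimeError("call bind() before init_params()")
        self.params_initialized = True

    def init_optimizer(self):
        # Assumption: an optimizer needs bound, initialized parameters.
        if not (self.binded and self.params_initialized):
            raise RuntimeError("call bind() and init_params() first")
        self.optimizer_initialized = True

    def state(self):
        """Return the most advanced state reached so far."""
        if self.optimizer_initialized:
            return "optimizer installed"
        if self.params_initialized:
            return "parameters initialized"
        if self.binded:
            return "bound"
        return "initial"


mod = ToyModule()
history = [mod.state()]
for step in (mod.bind, mod.init_params, mod.init_optimizer):
    step()
    history.append(mod.state())
print(history)
```

Calling the steps out of order (for example init_optimizer on a fresh ToyModule) raises RuntimeError in this sketch, which mirrors the idea that each state builds on the previous one.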

In order for a module to interact with others, it must be able to report the following information in its initial state (before binding):

  • data_names: list of strings giving the names of the required input data.
  • output_names: list of strings giving the names of the outputs.
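
Because both lists are available before binding, a container could check that two modules fit together without allocating any memory. The sketch below shows one plausible way to do that. The helper wire is hypothetical, the example names are illustrative, and pairing outputs to inputs by position is an assumption of this sketch rather than a documented rule of the module API.

```python
def wire(upstream_output_names, downstream_data_names):
    """Pair each upstream output with a downstream input, by position.

    Hypothetical helper: positional pairing is an assumption made for this
    sketch, not a documented rule of the module API.
    """
    if len(upstream_output_names) != len(downstream_data_names):
        raise ValueError(
            "upstream produces %d outputs but downstream expects %d inputs"
            % (len(upstream_output_names), len(downstream_data_names))
        )
    return list(zip(upstream_output_names, downstream_data_names))


# Names are illustrative only.
feature_extractor_outputs = ["fc2_output"]
classifier_inputs = ["data"]
pairs = wire(feature_extractor_outputs, classifier_inputs)
print(pairs)
```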

After binding, a module should be able to report the following richer information:

  • state information
    • binded: bool, indicates whether the memory buffers needed for computation have been allocated.
    • for_training: bool, indicates whether the module is bound for training.
    • params_initialized: bool, indicates whether the parameters of this module have been initialized.
    • optimizer_initialized: bool, indicates whether an optimizer is defined and initialized.
    • inputs_need_grad: bool, indicates whether gradients with respect to the input data are needed. This might be useful when composing modules.
  • input/output information
    • data_shapes: a list of (name, shape). In theory, since the memory is allocated, we could directly provide the data arrays. But in the case of data parallelism, the data arrays might not be of the same shape as viewed from the external world.
    • label_shapes: a list of (name, shape). This might be [] if the module does not need labels (e.g. it does not contain a loss function at the top), or if the module is not bound for training.
    • output_shapes: a list of (name, shape) for outputs of the module.
  • parameters (for modules with parameters)
    • get_params(): return a tuple (arg_params, aux_params). Each of these is a dictionary mapping names to NDArrays. These NDArrays always live on the CPU. The actual parameters used for computation might live on other devices (GPUs); this function retrieves (a copy of) the latest parameters.
    • set_params(arg_params, aux_params): assign parameters to the devices doing the computation.
    • init_params(...): a more flexible interface to assign or initialize the parameters.
  • setup
    • bind(): prepare environment for computation.
    • init_optimizer(): install optimizer for parameter updating.
    • prepare(): prepare the module based on the current data batch.
  • computation
    • forward(data_batch): forward operation.
    • backward(out_grads=None): backward operation.
    • update(): update parameters according to installed optimizer.
    • get_outputs(): get outputs of the previous forward operation.
    • get_input_grads(): get the gradients with respect to the inputs computed in the previous backward operation.
    • update_metric(metric, labels, pre_sliced=False): update performance metric for the previous forward computed results.
  • other properties (mostly for backward compatibility)
    • symbol: the underlying symbolic graph for this module (if any). This property is not necessarily constant. For example, for BucketingModule, this property is simply the current symbol being used. For other modules, this value might not be well defined.
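
The computation methods listed above are meant to be called in a fixed order within one training step: forward, backward, update, with get_outputs reading the result of the latest forward pass. The following is a minimal sketch of that order using a toy one-parameter model. ScalarModule is hypothetical, the (x, label) tuple stands in for a real data batch, and the plain SGD step stands in for whatever optimizer would be installed.

```python
class ScalarModule:
    """Hypothetical one-parameter model y = w * x with squared-error loss."""

    def __init__(self, w=0.0, learning_rate=0.1):
        self.w = w
        self.learning_rate = learning_rate
        self._x = self._label = self._output = self._grad = None

    def forward(self, data_batch):
        # data_batch is a plain (x, label) tuple in this sketch.
        self._x, self._label = data_batch
        self._output = self.w * self._x

    def backward(self):
        # d/dw of (w*x - label)^2 is 2 * (w*x - label) * x.
        self._grad = 2.0 * (self._output - self._label) * self._x

    def update(self):
        # Plain SGD step standing in for the installed optimizer.
        self.w -= self.learning_rate * self._grad

    def get_outputs(self):
        # Outputs of the most recent forward pass.
        return [self._output]


mod = ScalarModule(w=0.0, learning_rate=0.1)
batch = (1.0, 2.0)  # x = 1, label = 2

mod.forward(batch)
first_output = mod.get_outputs()[0]
mod.backward()
mod.update()

# A second forward pass sees the updated parameter.
mod.forward(batch)
second_output = mod.get_outputs()[0]
print(first_output, second_output)
```

After one step the output moves from 0 toward the label, which is the only behavior the sketch is meant to show.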

When these intermediate-level APIs are implemented properly, the following high-level APIs are automatically available for a module:

  • fit: train the module parameters on a data set.
  • predict: run prediction on a data set and collect outputs.
  • score: run prediction on a data set and evaluate performance.
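
To illustrate why the high-level API can be derived from the intermediate-level one, here is a simplified sketch of a training loop built only from forward_backward, update, and update_metric. RecordingModule and fit_sketch are hypothetical names, the (data, label) tuples stand in for real data batches, and the loop leaves out most of what an actual fit has to handle (binding, initialization, callbacks, evaluation data, and so on).

```python
class RecordingModule:
    """Hypothetical module that only logs which methods were called."""

    def __init__(self):
        self.calls = []

    def forward(self, data_batch, is_train=None):
        self.calls.append("forward")

    def backward(self, out_grads=None):
        self.calls.append("backward")

    def forward_backward(self, data_batch):
        # Convenience wrapper: one forward pass followed by one backward pass.
        self.forward(data_batch, is_train=True)
        self.backward()

    def update(self):
        self.calls.append("update")

    def update_metric(self, eval_metric, labels, pre_sliced=False):
        self.calls.append("update_metric")


def fit_sketch(module, train_data, eval_metric, num_epoch):
    """Simplified training loop; a real fit does considerably more."""
    for _ in range(num_epoch):
        for data, label in train_data:
            module.forward_backward(data)
            module.update()
            module.update_metric(eval_metric, label)


mod = RecordingModule()
batches = [("batch0", "label0"), ("batch1", "label1")]
fit_sketch(mod, batches, eval_metric=None, num_epoch=1)
print(mod.calls)
```

Any class that implements these four methods could be passed to fit_sketch unchanged, which is the design point: the loop depends on the interface, not on a particular module type.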


>>> # An example of creating an MXNet module.
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
>>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
>>> fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
>>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type='relu')
>>> fc3  = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
>>> out  = mx.symbol.SoftmaxOutput(fc3, name='softmax')
>>> mod = mx.mod.Module(out)
__init__(logger=logging)

Initialize self. See help(type(self)) for accurate signature.


__init__([logger]) Initialize self.
backward([out_grads]) Backward computation.
bind(data_shapes[, label_shapes, …]) Binds the symbols to construct executors.
fit(train_data[, eval_data, eval_metric, …]) Trains the module parameters.
forward(data_batch[, is_train]) Forward computation.
forward_backward(data_batch) A convenient function that calls both forward and backward.
get_input_grads([merge_multi_context]) Gets the gradients to the inputs, computed in the previous backward computation.
get_outputs([merge_multi_context]) Gets outputs of the previous forward computation.
get_params() Gets parameters; these are potentially copies of the actual parameters used to do computation on the device.
get_states([merge_multi_context]) Gets states from all devices.
init_optimizer([kvstore, optimizer, …]) Installs and initializes optimizers, as well as initialize kvstore for
init_params([initializer, arg_params, …]) Initializes the parameters and auxiliary states.
install_monitor(mon) Installs monitor on all executors.
iter_predict(eval_data[, num_batch, reset, …]) Iterates over predictions.
load_params(fname) Loads model parameters from file.
predict(eval_data[, num_batch, …]) Runs prediction and collects the outputs.
prepare(data_batch[, sparse_row_id_fn]) Prepares the module for processing a data batch.
save_params(fname) Saves model parameters to file.
score(eval_data, eval_metric[, num_batch, …]) Runs prediction on eval_data and evaluates the performance according to the given eval_metric.
set_params(arg_params, aux_params[, …]) Assigns parameter and aux state values.
set_states([states, value]) Sets value for states.
update() Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.
update_metric(eval_metric, labels[, pre_sliced]) Evaluates and accumulates evaluation metric on outputs of the last forward computation.


data_names A list of names for data required by this module.
data_shapes A list of (name, shape) pairs specifying the data inputs to this module.
label_shapes A list of (name, shape) pairs specifying the label inputs to this module.
output_names A list of names for the outputs of this module.
output_shapes A list of (name, shape) pairs specifying the outputs of this module.
symbol Gets the symbol associated with this module.