

class mxnet.module.BucketingModule(sym_gen, default_bucket_key=None, logger=logging, context=cpu(0), work_load_list=None, fixed_param_names=None, state_names=None, group2ctxs=None, compression_params=None)

This module helps deal efficiently with varying-length inputs by binding a separate executor for each bucket key while sharing parameters across all buckets.
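To make the idea concrete, here is a minimal pure-Python sketch of how variable-length sequences are typically grouped before they reach a BucketingModule. Each sequence goes to the smallest bucket length that can hold it and is padded to that length, and that bucket length then serves as the bucket key. The function name assign_to_buckets, the padding value 0, and the drop-if-too-long policy are all hypothetical. In practice this grouping is done by the data iterator (for example mx.rnn.BucketSentenceIter), not by the module itself.

```python
from bisect import bisect_left
from typing import Dict, List, Sequence


def assign_to_buckets(sequences: Sequence[List[int]],
                      buckets: Sequence[int],
                      pad: int = 0) -> Dict[int, List[List[int]]]:
    """Hypothetical helper: group sequences by the smallest bucket length
    that can hold them, padding each sequence to that length.

    Sequences longer than the largest bucket are dropped (one possible
    policy, chosen here only for illustration).
    """
    buckets = sorted(buckets)
    grouped: Dict[int, List[List[int]]] = {b: [] for b in buckets}
    for seq in sequences:
        i = bisect_left(buckets, len(seq))  # index of first bucket >= len(seq)
        if i == len(buckets):
            continue  # too long for any bucket
        key = buckets[i]
        grouped[key].append(list(seq) + [pad] * (key - len(seq)))
    return grouped
```

For example, with buckets [2, 5], the sequences [1, 2], [3, 4, 5, 6, 7] and [8] end up as {2: [[1, 2], [8, 0]], 5: [[3, 4, 5, 6, 7]]}, so the module only ever needs executors for two input shapes.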

Parameters:

  • sym_gen (function) – A function that, when called with a bucket key, returns a triple (symbol, data_names, label_names).

  • default_bucket_key (str (or any python object)) – The key for the default bucket.

  • logger (Logger) – Defaults to the standard logging module.

  • context (Context or list of Context) – Defaults to mx.cpu().

  • work_load_list (list of number) – Defaults to None, indicating uniform workload.

  • fixed_param_names (list of str) – Defaults to None, indicating no network parameters are fixed.

  • state_names (list of str) – States are similar to data and labels, but are not provided by the data iterator. Instead, they are initialized to 0 and can be set by set_states().

  • group2ctxs (dict of str to context or list of context, or list of dict of str to context) – Defaults to None. Maps the ctx_group attribute to the context assignment.

  • compression_params (dict) – Specifies the type of gradient compression and additional arguments depending on the type of compression being used. For example, 2bit compression requires a threshold, so the arguments would be {'type': '2bit', 'threshold': 0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
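The 2bit compression example in compression_params can be illustrated with a minimal pure-Python sketch of 2-bit quantization with error feedback. It assumes that each value of (gradient + stored residual) is sent as +threshold, -threshold, or 0, and that whatever was not sent is kept as a residual and added to the next gradient. The function name two_bit_compress is hypothetical, and the exact boundary handling inside MXNet's KVStore may differ.

```python
from typing import List, Tuple


def two_bit_compress(grad: List[float], residual: List[float],
                     threshold: float = 0.5) -> Tuple[List[float], List[float]]:
    """Hypothetical sketch: quantize grad + residual to {-threshold, 0, +threshold}.

    Returns (quantized, new_residual). The residual carries the part of the
    gradient that was not sent, so it is added back on the next step.
    """
    quantized, new_residual = [], []
    for g, r in zip(grad, residual):
        acc = g + r  # error feedback: include what was dropped last time
        if acc >= threshold:
            q = threshold
        elif acc <= -threshold:
            q = -threshold
        else:
            q = 0.0
        quantized.append(q)
        new_residual.append(acc - q)  # remember the unsent remainder
    return quantized, new_residual
```

With the module itself you only pass compression_params={'type': '2bit', 'threshold': 0.5}; the quantization is handled by the KVStore, not by user code. The residual is what keeps small gradients from being lost: a value too small to send on one step accumulates until it crosses the threshold.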

__init__(sym_gen, default_bucket_key=None, logger=logging, context=cpu(0), work_load_list=None, fixed_param_names=None, state_names=None, group2ctxs=None, compression_params=None)

Initialize self. See help(type(self)) for accurate signature.


Methods:

__init__(sym_gen[, default_bucket_key, …])

Initialize self.


backward([out_grads])

Backward computation.

bind(data_shapes[, label_shapes, …])

Binding for a BucketingModule means setting up the buckets and binding the executor for the default bucket key.

fit(train_data[, eval_data, eval_metric, …])

Trains the module parameters.

forward(data_batch[, is_train])

Forward computation.


forward_backward(data_batch)

A convenient function that calls both forward and backward.


get_input_grads([merge_multi_context])

Gets the gradients with respect to the inputs of the module.


get_outputs([merge_multi_context])

Gets outputs from a previous forward computation.


get_params()

Gets current parameters.


get_states([merge_multi_context])

Gets states from all devices.

init_optimizer([kvstore, optimizer, …])

Installs and initializes optimizers.

init_params([initializer, arg_params, …])

Initializes parameters.


install_monitor(mon)

Installs monitor on all executors.

iter_predict(eval_data[, num_batch, reset, …])

Iterates over predictions.


load_params(fname)

Loads model parameters from file.

predict(eval_data[, num_batch, …])

Runs prediction and collects the outputs.

prepare(data_batch[, sparse_row_id_fn])

Prepares the module for processing a data batch.


save_params(fname)

Saves model parameters to file.

score(eval_data, eval_metric[, num_batch, …])

Runs prediction on eval_data and evaluates the performance according to the given eval_metric.

set_params(arg_params, aux_params[, …])

Assigns parameters and aux state values.

set_states([states, value])

Sets value for states.

switch_bucket(bucket_key, data_shapes[, …])

Switches to a different bucket.


update()

Updates parameters according to installed optimizer and the gradient computed in the previous forward-backward cycle.

update_metric(eval_metric, labels[, pre_sliced])

Evaluates and accumulates evaluation metric on outputs of the last forward computation.
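The interplay between sym_gen, bind, and switch_bucket can be sketched as a toy that models only the bookkeeping: one network per bucket key, built on first use by calling sym_gen, with bind starting from the default bucket. Everything here (ToyBucketingModule, the string standing in for a symbol, the bucket keys 50 and 10) is hypothetical. The real BucketingModule binds an executor for each bucket and shares parameters with the default bucket's module, which this toy does not model.

```python
from typing import Callable, Dict, List, Optional, Tuple

# A (symbol, data_names, label_names) triple; the "symbol" is only a
# descriptive string here, standing in for a real MXNet symbol.
Triple = Tuple[str, List[str], List[str]]


def sym_gen(seq_len: int) -> Triple:
    """Hypothetical sym_gen: describe the network for one bucket key."""
    symbol = f"unrolled_rnn(seq_len={seq_len})"
    return symbol, ["data"], ["softmax_label"]


class ToyBucketingModule:
    """Toy bookkeeping model only, not MXNet's implementation."""

    def __init__(self, sym_gen: Callable[[int], Triple],
                 default_bucket_key: int) -> None:
        self._sym_gen = sym_gen
        self._default_bucket_key = default_bucket_key
        self._buckets: Dict[int, Triple] = {}  # one entry per bucket key seen
        self.curr_key: Optional[int] = None

    def bind(self) -> None:
        # Binding sets up the default bucket first.
        self.switch_bucket(self._default_bucket_key)

    def switch_bucket(self, bucket_key: int) -> None:
        # Call sym_gen only the first time a bucket key is seen.
        if bucket_key not in self._buckets:
            self._buckets[bucket_key] = self._sym_gen(bucket_key)
        self.curr_key = bucket_key

    @property
    def symbol(self) -> str:
        # The symbol of the bucket currently in use.
        return self._buckets[self.curr_key][0]

    @property
    def data_names(self) -> List[str]:
        # Data names reported by sym_gen for the current bucket.
        return self._buckets[self.curr_key][1]
```

After mod = ToyBucketingModule(sym_gen, default_bucket_key=50), mod.bind(), and mod.switch_bucket(10), mod.symbol reports the 10-step network while the 50-step one stays cached, mirroring how switching back to a bucket reuses what was already built.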



Attributes:

data_names

A list of names for data required by this module.


data_shapes

Gets data shapes.


label_shapes

Gets label shapes.


output_names

A list of names for the outputs of this module.


output_shapes

Gets output shapes.


symbol

The symbol of the current bucket being used.