

class mxnet.gluon.Trainer(params, optimizer, optimizer_params=None, kvstore='device', compression_params=None, update_on_kvstore=None)

Applies an Optimizer to a set of Parameters. Trainer should be used together with autograd.

Parameters:

  • params (ParameterDict) – The set of parameters to optimize.
  • optimizer (str or Optimizer) – The optimizer to use. See help on Optimizer for a list of available optimizers.
  • optimizer_params (dict) – Keyword arguments passed to the optimizer constructor. For example, {'learning_rate': 0.1}. All optimizers accept learning_rate, wd (weight decay), clip_gradient, and lr_scheduler. See each optimizer's constructor for a list of additional supported arguments.
  • kvstore (str or KVStore) – KVStore type for multi-GPU and distributed training. See help on mxnet.kvstore.create for more information.
  • compression_params (dict) – Specifies the type of gradient compression and any additional arguments that type requires. For example, 2bit compression requires a threshold, so the arguments would be {'type': '2bit', 'threshold': 0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
  • update_on_kvstore (bool, default None) – Whether to perform parameter updates on the kvstore. If None, the trainer chooses the more suitable option depending on the type of kvstore.

Properties:

  • learning_rate (float) – The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate.
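To make the threshold in compression_params concrete, here is a minimal pure-Python sketch of threshold-based 2-bit quantization with error feedback, the scheme MXNet's '2bit' compression is described as using: each entry of (gradient + carried-over residual) is quantized to +threshold, -threshold or 0, and whatever was not transmitted is kept as a residual for the next step. The function name two_bit_compress is hypothetical; this illustrates the idea only and is not MXNet's implementation, which runs inside the KVStore and packs each value into 2 bits.

```python
from typing import List, Tuple


def two_bit_compress(grad: List[float], residual: List[float],
                     threshold: float) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: quantize grad + residual to {-threshold, 0, +threshold}.

    Returns (quantized, new_residual): the values that would be sent over
    the network and the quantization error carried into the next step.
    """
    quantized: List[float] = []
    new_residual: List[float] = []
    for g, r in zip(grad, residual):
        v = g + r  # add back the error left over from the previous step
        if v >= threshold:
            q = threshold
        elif v <= -threshold:
            q = -threshold
        else:
            q = 0.0
        quantized.append(q)
        new_residual.append(v - q)  # remember what was not transmitted
    return quantized, new_residual


# Corresponds to compression_params={'type': '2bit', 'threshold': 0.5}
grad = [0.7, -0.2, 0.3, -0.9]
residual = [0.0] * len(grad)

q1, residual = two_bit_compress(grad, residual, threshold=0.5)
print(q1)  # [0.5, 0.0, 0.0, -0.5]

# Same gradient again: the 0.3 entry has now accumulated enough residual.
q2, residual = two_bit_compress(grad, residual, threshold=0.5)
print(q2)  # [0.5, 0.0, 0.5, -0.5]
```

The residual is what keeps small gradients from being discarded outright: the 0.3 entry is sent on the second step, once its accumulated value crosses the threshold.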

Updating parameters

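  • Trainer.step(batch_size[, ignore_stale_grad]) – Makes one step of parameter update. Equivalent to calling allreduce_grads() followed by update(); gradients are normalized by 1/batch_size.
  • Trainer.allreduce_grads() – For each parameter, reduces the gradients from different contexts.
  • Trainer.update(batch_size[, ignore_stale_grad]) – Makes one step of parameter update without reducing gradients first; call it after allreduce_grads(). Calling the two separately is useful when the reduced gradients must be transformed in between (e.g. gradient clipping).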
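The relationship between the three methods can be sketched in pure Python for a single parameter and plain SGD. This is a simplified model, not MXNet's code: all helper names are hypothetical, "reduce" is modeled as an element-wise sum over devices, and the batch_size normalization is modeled as scaling the reduced gradient by 1/batch_size before the update (which is why batch_size should be 1 if the loss is already averaged).

```python
from typing import List


def allreduce_grads(grads_per_device: List[List[float]]) -> List[float]:
    """Sum one parameter's gradient element-wise across all devices."""
    return [sum(vals) for vals in zip(*grads_per_device)]


def update(weight: List[float], grad: List[float], lr: float,
           batch_size: int) -> List[float]:
    """Plain SGD on an already-reduced gradient, normalized by 1/batch_size."""
    rescale_grad = 1.0 / batch_size
    return [w - lr * rescale_grad * g for w, g in zip(weight, grad)]


def step(weight: List[float], grads_per_device: List[List[float]],
         lr: float, batch_size: int) -> List[float]:
    """Models step() as allreduce_grads() followed by update()."""
    return update(weight, allreduce_grads(grads_per_device), lr, batch_size)


def clip(grad: List[float], limit: float) -> List[float]:
    """Element-wise clipping, used here as an example transformation."""
    return [max(-limit, min(limit, g)) for g in grad]


weight = [1.0, 2.0]
# Two devices, each holding gradients summed over its half of a batch of 4.
grads = [[2.0, 4.0], [6.0, 8.0]]

w_step = step(weight, grads, lr=0.1, batch_size=4)
print([round(w, 6) for w in w_step])  # [0.8, 1.7]

# Manual path: reduce, transform the reduced gradient, then update.
reduced = allreduce_grads(grads)  # [8.0, 12.0]
w_manual = update(weight, clip(reduced, 10.0), lr=0.1, batch_size=4)
print([round(w, 6) for w in w_manual])  # [0.8, 1.75]
```

Splitting the call only changes the result when something happens between the reduction and the update, as the clipped second element shows.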

Trainer States

  • Trainer.load_states(fname) – Loads trainer states (e.g. optimizer, momentum) from a file.
  • Trainer.save_states(fname) – Saves trainer states (e.g. optimizer, momentum) to a file.
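Why the optimizer state matters when resuming training can be shown with a small pure-Python sketch: a momentum optimizer keeps a per-weight velocity buffer, and a checkpoint that saved only the weights would silently reset it. The class MomentumSGDSketch and its JSON file layout are hypothetical and are not MXNet's on-disk format; with a real Trainer the corresponding calls are trainer.save_states(fname) and trainer.load_states(fname).

```python
import json
import os
import tempfile
from typing import List


class MomentumSGDSketch:
    """Hypothetical momentum optimizer whose state must survive a restart."""

    def __init__(self, lr: float = 0.1, momentum: float = 0.9) -> None:
        self.lr = lr
        self.momentum = momentum
        self.velocity: List[float] = []  # per-weight momentum buffer
        self.num_update = 0              # number of updates applied so far

    def update(self, weight: List[float], grad: List[float]) -> List[float]:
        if not self.velocity:
            self.velocity = [0.0] * len(weight)
        self.velocity = [self.momentum * v - self.lr * g
                         for v, g in zip(self.velocity, grad)]
        self.num_update += 1
        return [w + v for w, v in zip(weight, self.velocity)]

    def save_states(self, fname: str) -> None:
        # Persist the optimizer state, not the weights themselves.
        with open(fname, "w") as f:
            json.dump({"velocity": self.velocity,
                       "num_update": self.num_update}, f)

    def load_states(self, fname: str) -> None:
        with open(fname) as f:
            state = json.load(f)
        self.velocity = state["velocity"]
        self.num_update = state["num_update"]


opt = MomentumSGDSketch()
weight = opt.update([1.0, -1.0], [0.5, -0.5])

path = os.path.join(tempfile.mkdtemp(), "trainer.states")
opt.save_states(path)

resumed = MomentumSGDSketch()  # e.g. a fresh process after a crash
resumed.load_states(path)

# Both optimizers now produce the same next update.
next_a = opt.update(weight, [0.5, -0.5])
next_b = resumed.update(weight, [0.5, -0.5])
print(next_a == next_b)  # True
```

Without the load_states call, the resumed optimizer would start from a zero velocity and its next update would differ.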

Learning rate

  • Trainer.set_learning_rate(lr) – Sets a new learning rate of the optimizer.
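A common use of set_learning_rate is a hand-written schedule applied between epochs (passing an lr_scheduler in optimizer_params is the built-in alternative). The sketch below is pure Python: TrainerSketch is a hypothetical stand-in that mirrors only the learning_rate property and set_learning_rate(lr) described above, and step_decay is an illustrative helper, not an MXNet function.

```python
class TrainerSketch:
    """Hypothetical stand-in exposing only the learning-rate API of Trainer."""

    def __init__(self, learning_rate: float) -> None:
        self._lr = learning_rate

    @property
    def learning_rate(self) -> float:
        """Current learning rate (read-only property)."""
        return self._lr

    def set_learning_rate(self, lr: float) -> None:
        """Replace the learning rate used by subsequent updates."""
        self._lr = lr


def step_decay(base_lr: float, epoch: int, factor: float = 0.1,
               every: int = 10) -> float:
    """Multiply base_lr by `factor` once every `every` epochs."""
    return base_lr * factor ** (epoch // every)


trainer = TrainerSketch(learning_rate=0.1)
history = []
for epoch in range(25):
    trainer.set_learning_rate(step_decay(0.1, epoch))
    history.append(trainer.learning_rate)
    # With a real Trainer, one epoch of forward/backward passes and
    # trainer.step(batch_size) calls would run here.

print(history[0], round(history[10], 6), round(history[24], 6))  # 0.1 0.01 0.001
```

Setting the rate once per epoch, before any step() calls in that epoch, keeps every update within an epoch on the same rate.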