Deployer

Handles low-level operations to support Trainer and Predictor, e.g., automatic data/model parallelism, distributed checkpointing, data processing, logging, and randomness control.

Attributes:

Name     Type      Description
-------  --------  ---------------------------------------------------
workdir  str       Working directory for saving checkpoints and logs.
mesh     jax Mesh  Mesh used for model sharding.
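
Example (a minimal single-host sketch; it assumes Deployer is importable from the top-level redco package -- adjust the import to your installation):

from redco import Deployer

# Seeds the internal PRNG key, keeps everything data-parallel (no model
# sharding), and writes logs/checkpoints under ./workdir.
deployer = Deployer(jax_seed=42, n_model_shards=1, workdir='./workdir')
rng = deployer.gen_rng()  # fresh key; the internal random state advances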

Source code in redco/deployers/deployer.py
class Deployer:
    """ Handles low-level operations to support Trainer and Predictor,
        e.g., automatic data/model parallelism, distributed checkpointing,
        data processing, logging, randomness controlling, etc.

    Attributes:
        workdir (str): Working directory for saving checkpoints and logs.
        mesh (jax Mesh): Mesh used for model sharding.
    """
    def __init__(self,
                 jax_seed,
                 n_model_shards=1,
                 verbose=True,
                 workdir=None,
                 n_processes=None,
                 host0_address=None,
                 host0_port=None,
                 process_id=None,
                 n_local_devices=None,
                 run_tensorboard=False,
                 wandb_init_kwargs=None):
        """ Initializes a Deployer.

        Args:
            jax_seed (int): Seed for random number generation.
            n_model_shards (int): Number of shards for running large model.
            verbose (bool): Whether to enable verbose logging.
            workdir (str):  Directory for saving logs and checkpoints.
            n_processes (int):  For multi-host, number of processes/nodes.
            host0_address (str):  For multi-host, address of the host0.
            host0_port (int): For multi-host, port of the host0.
            process_id (int): For multi-host, index of the current process.
            n_local_devices (int): For multi-host, number of local devices.
            run_tensorboard (bool):  Whether to enable TensorBoard logging.
            wandb_init_kwargs (dict): wandb.init arguments if using wandb.
        """
        if n_processes is None:
            if 'SLURM_JOB_NUM_NODES' in os.environ:
                n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])
                process_id = int(os.environ['SLURM_NODEID'])
            else:
                n_processes = 1

        if n_processes > 1:
            local_device_ids = None if n_local_devices is None \
                else list(range(n_local_devices))

            if host0_port is None:
                host0_port = DEFAULT_HOST0_PORT

            jax.distributed.initialize(
                coordinator_address=f'{host0_address}:{host0_port}',
                num_processes=n_processes,
                process_id=process_id,
                local_device_ids=local_device_ids)

        if workdir is not None:
            os.makedirs(workdir, exist_ok=True)

        self._verbose = verbose
        self._workdir = workdir
        self._logger = get_logger(verbose=verbose, workdir=workdir)

        if wandb_init_kwargs is not None and jax.process_index() == 0:
            import wandb
            wandb.init(**wandb_init_kwargs)
            self._wandb_log_fn = wandb.log
        else:
            self._wandb_log_fn = None

        if run_tensorboard and jax.process_index() == 0:
            from flax.metrics import tensorboard
            self._summary_writer = tensorboard.SummaryWriter(workdir)
        else:
            self._summary_writer = None

        self.log_info(
            f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')

        self._rng = jax.random.PRNGKey(seed=jax_seed)
        self._mesh = get_mesh(n_model_shards=n_model_shards)
        self._checkpointer = ocp.PyTreeCheckpointer()

    def get_local_global_micro_batch_size(self, per_device_batch_size):
        """Get local/global micro batch sizes based on per-device batch size."""
        if self._mesh is None:
            local_micro_batch_size = \
                per_device_batch_size * jax.local_device_count()
            global_micro_batch_size = \
                local_micro_batch_size * jax.process_count()
        else:
            global_micro_batch_size = local_micro_batch_size = \
                per_device_batch_size * self._mesh.shape['dp']

        return local_micro_batch_size, global_micro_batch_size

    def get_accumulate_grad_batches(
            self, global_batch_size, per_device_batch_size):
        """Calculates the number of gradient accumulation batches."""
        _, global_micro_batch_size = self.get_local_global_micro_batch_size(
            per_device_batch_size=per_device_batch_size)
        assert global_batch_size % global_micro_batch_size == 0
        accumulate_grad_batches = global_batch_size // global_micro_batch_size

        return accumulate_grad_batches

    def get_model_input_batches(self,
                                examples,
                                per_device_batch_size,
                                collate_fn,
                                shuffle,
                                shuffle_rng,
                                desc,
                                is_train=False,
                                accumulate_grad_batches=None):
        """Prepares model input batches from examples.

        Args:
            examples (list): List of input examples.
            per_device_batch_size (int): Batch size per device.
            collate_fn (Callable): Function to collate the examples.
            shuffle (bool): Whether to shuffle the examples.
            shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.
            desc (str): Description in the progress bar.
            is_train (bool): Whether the data is for training.
            accumulate_grad_batches (int): gradient accumulation batches.

        Returns:
            (generator): A python generator of batched model inputs.
        """
        local_micro_batch_size, global_micro_batch_size = \
            self.get_local_global_micro_batch_size(
                per_device_batch_size=per_device_batch_size)

        examples = get_host_examples(
            examples=examples,
            global_micro_batch_size=global_micro_batch_size,
            shuffle=shuffle,
            shuffle_rng=shuffle_rng,
            mesh=self._mesh)

        if not is_train:
            desc = f'{desc} (global_batch_size = {global_micro_batch_size})'
        elif accumulate_grad_batches is None:
            desc = \
                f'{desc} (global_micro_batch_size = {global_micro_batch_size})'
        else:
            desc = (f'{desc} ('
                    f'global_micro_batch_size = {global_micro_batch_size}, '
                    f'accumulate_grad_batches = {accumulate_grad_batches})')

        return get_data_batches(
            examples=examples,
            batch_size=local_micro_batch_size,
            collate_fn=collate_fn,
            mesh=self._mesh,
            desc=desc,
            verbose=self._verbose)

    def get_lr_schedule_fn(self,
                           train_size,
                           per_device_batch_size,
                           n_epochs,
                           learning_rate,
                           schedule_type='linear',
                           warmup_ratio=0.,
                           warmup_steps=None,
                           init_learning_rate=0.,
                           end_learning_rate=0.):
        """Creates a learning rate schedule function.

        Args:
            train_size (int): Number of training examples per epoch.
            per_device_batch_size (int): Batch size per device.
            n_epochs (int): Number of epochs.
            learning_rate (float): Peak learning rate.
            schedule_type (str): Type of lr schedule, "linear" or "cosine".
            warmup_ratio (float): Ratio of lr warmup.
            warmup_steps (int): Number of warmup steps.
            init_learning_rate (float): Initial learning rate before warmup.
            end_learning_rate (float): End learning rate for the schedule.

        Returns:
            (Callable): A lr schedule function, step -> learning rate.
        """
        _, global_micro_batch_size = self.get_local_global_micro_batch_size(
            per_device_batch_size=per_device_batch_size)
        total_train_steps = n_epochs * (train_size // global_micro_batch_size)

        if warmup_steps is None:
            warmup_steps = int(total_train_steps * warmup_ratio)

        return get_lr_schedule_fn(
            schedule_type=schedule_type,
            total_train_steps=total_train_steps,
            warmup_steps=warmup_steps,
            init_learning_rate=init_learning_rate,
            learning_rate=learning_rate,
            end_learning_rate=end_learning_rate)

    def get_sharding_rules(self, params_shape_or_params):
        """Get sharding rules based on the parameter shapes."""
        if self._mesh is None:
            return None
        else:
            sharding_rules = get_sharding_rules(
                params_shape_or_params=params_shape_or_params,
                n_model_shards=self._mesh.shape['mp'])
            return sharding_rules

    def get_params_spec(self, params_shape_or_params, params_sharding_rules):
        """Generates parameter specs based on sharding rules."""
        return get_params_spec(
            params_shape_or_params=params_shape_or_params,
            params_sharding_rules=params_sharding_rules)

    def get_opt_state_spec(
            self, params_shape_or_params, params_spec, optimizer):
        """Get optimizer state specs"""
        return get_opt_state_spec(
            params_shape_or_params=params_shape_or_params,
            params_spec=params_spec,
            optimizer=optimizer)

    def shard_params(self, params, params_spec, desc='params'):
        """Distributes parameters to all devices based on the provided specs."""
        self.log_info(info=f'Sharding {desc} ...')
        return shard_params(
            mesh=self._mesh, params=params, params_spec=params_spec)

    def run_model_step(self, step_fn, input_args):
        """Executes a model step function with the provided inputs."""
        if self._mesh is None:
            return step_fn(*input_args)
        else:
            with self._mesh:
                return step_fn(*input_args)

    def gen_rng(self):
        """Get a new random number generator key and update the random state."""
        self._rng, new_rng = jax.random.split(self._rng)
        return new_rng

    def gen_model_step_rng(self):
        """Get a new random number generator key for distributed model step and
        update the random state.
        """
        rng = self.gen_rng()
        if self.mesh is None:
            rng = jax.random.split(
                rng, num=jax.process_count())[jax.process_index()]
            rng = shard_prng_key(rng)
        return rng

    def log_info(self, info, title=None, step=None):
        """Logs a messages"""
        log_info(
            info=info,
            title=title,
            logger=self._logger,
            summary_writer=self._summary_writer,
            step=step)

    def log_metrics(self, metrics, step):
        """Logs metrics to TensorBoard and Weights and Biases (wandb)."""
        if self._summary_writer is not None:
            for metric_name, value in metrics.items():
                self._summary_writer.scalar(metric_name, value, step=step)

        if self._wandb_log_fn is not None:
            self._wandb_log_fn(metrics, step)

    def save_outputs(self, outputs, desc, step):
        """Saves model outputs to workdir."""
        if self._workdir is not None and jax.process_index() == 0:
            save_outputs(
                workdir=self._workdir,
                outputs=outputs,
                desc=desc,
                step=step,
                logger=self._logger,
                summary_writer=self._summary_writer)

    def save_ckpt(
            self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):
        """Saves a checkpoint to the specified directory.

        Args:
            ckpt_dir (str): Directory to save the checkpoint.
            params (dict): Model parameters.
            opt_state (dict): Optimizer state.
            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.
            **kwargs (dict): Additional information to be saved into
                info.json, e.g., current training step, epoch index, etc.
        """
        ckpt_dir = os.path.abspath(ckpt_dir)
        self.log_info(f'Saving ckpt to {ckpt_dir} ...')
        save_ckpt(
            ckpt_dir=ckpt_dir,
            checkpointer=self._checkpointer,
            params=params,
            opt_state=opt_state,
            float_dtype=float_dtype,
            rng=self._rng,
            **kwargs)
        self.log_info(f'Ckpt saved into {ckpt_dir}')

    def load_params_shape(self, ckpt_dir):
        """Loads the shape of the parameters from a checkpoint."""
        return load_params_shape(ckpt_dir=ckpt_dir)

    def load_ckpt(self,
                  ckpt_dir,
                  params_sharding_rules=None,
                  optimizer=None,
                  float_dtype=None,
                  load_params=True,
                  load_opt_state=True,
                  update_rng=False):
        """Loads a checkpoint from the specified directory.

        Args:
            ckpt_dir (str): Directory of the checkpoint.
            params_sharding_rules (PyTree): Sharding rules for parameters.
            optimizer (optax optimizer): Optimizer for loading opt_state.
            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.
            load_params (bool): Whether to load the parameters.
            load_opt_state (bool): Whether to load the optimizer state.
            update_rng (bool): if updating the random state of the deployer.

        Returns:
            (tuple): A tuple with the loaded checkpoint (in a dict with
                `"params"` and `"opt_state"`) and additional information (in a
                dict, usually including `"steps"`, `"epoch_idx"`, and `"rng"`).
        """
        ckpt_dir = os.path.abspath(ckpt_dir)
        self.log_info(f'Loading ckpt from {ckpt_dir} ...')

        params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)

        specs = {}
        if self._mesh is not None:
            if params_sharding_rules is None:
                params_sharding_rules = self.get_sharding_rules(
                    params_shape_or_params=params_shape)

            specs['params'] = self.get_params_spec(
                params_shape_or_params=params_shape,
                params_sharding_rules=params_sharding_rules)
            if optimizer is not None:
                specs['opt_state'] = self.get_opt_state_spec(
                    params_shape_or_params=params_shape,
                    params_spec=specs['params'],
                    optimizer=optimizer)

        ckpt, info = load_ckpt(
            ckpt_dir=ckpt_dir,
            checkpointer=self._checkpointer,
            params_shape_or_params=params_shape,
            optimizer=optimizer,
            float_dtype=float_dtype,
            mesh=self._mesh,
            specs=specs,
            load_params=load_params,
            load_opt_state=load_opt_state)

        for key, value in info.items():
            if not update_rng and key == 'rng':
                continue
            self.log_info(f'{ckpt_dir}::{key} = {value}')

        if update_rng:
            self._rng = info['rng']
            self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')

        return ckpt, info

    def load_last_ckpt(self,
                       optimizer=None,
                       params_sharding_rules=None,
                       float_dtype=None,
                       load_params=True,
                       load_opt_state=True,
                       update_rng=True):
        """Loads the last checkpoint from the work directory (self.workdir).
        See load_ckpt() for the explanation of arguments.
        """
        try:
            last_ckpt_name = open(
                f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()
        except:
            self.log_info(
                f'{self._workdir}/ckpts/last_ckpt.txt not found. '
                f'no ckpt loaded.')
            return None, None

        return self.load_ckpt(
            ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',
            optimizer=optimizer,
            float_dtype=float_dtype,
            params_sharding_rules=params_sharding_rules,
            load_params=load_params,
            load_opt_state=load_opt_state,
            update_rng=update_rng)

    @property
    def mesh(self):
        """Returns the mesh for model sharding"""
        return self._mesh

    @property
    def workdir(self):
        """Returns the work directory."""
        return self._workdir

mesh property

Returns the mesh for model sharding.

workdir property

Returns the work directory.

__init__(jax_seed, n_model_shards=1, verbose=True, workdir=None, n_processes=None, host0_address=None, host0_port=None, process_id=None, n_local_devices=None, run_tensorboard=False, wandb_init_kwargs=None)

Initializes a Deployer.

Parameters:

Name               Type  Description                                     Default
-----------------  ----  ----------------------------------------------  --------
jax_seed           int   Seed for random number generation.              required
n_model_shards     int   Number of shards for running a large model.     1
verbose            bool  Whether to enable verbose logging.              True
workdir            str   Directory for saving logs and checkpoints.      None
n_processes        int   For multi-host, number of processes/nodes.      None
host0_address      str   For multi-host, address of host 0.              None
host0_port         int   For multi-host, port of host 0.                 None
process_id         int   For multi-host, index of the current process.   None
n_local_devices    int   For multi-host, number of local devices.        None
run_tensorboard    bool  Whether to enable TensorBoard logging.          False
wandb_init_kwargs  dict  wandb.init arguments if using wandb.            None
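
For a multi-host run launched outside SLURM, the coordinator arguments are passed explicitly. A hedged sketch (the addresses, ports, and device counts below are placeholders, not values from the source):

# Run once per node; only process_id changes across nodes.
deployer = Deployer(
    jax_seed=42,
    n_model_shards=4,
    workdir='./workdir',
    n_processes=2,             # total number of nodes/processes
    host0_address='10.0.0.1',  # placeholder address of host 0 (the coordinator)
    host0_port=8888,           # placeholder; DEFAULT_HOST0_PORT is used if omitted
    process_id=0,              # 0 on host 0, 1 on the second node, ...
    n_local_devices=8)         # placeholder number of devices per node
# Under SLURM, n_processes and process_id can be omitted: they are read from
# SLURM_JOB_NUM_NODES and SLURM_NODEID.
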
Source code in redco/deployers/deployer.py
def __init__(self,
             jax_seed,
             n_model_shards=1,
             verbose=True,
             workdir=None,
             n_processes=None,
             host0_address=None,
             host0_port=None,
             process_id=None,
             n_local_devices=None,
             run_tensorboard=False,
             wandb_init_kwargs=None):
    """ Initializes a Deployer.

    Args:
        jax_seed (int): Seed for random number generation.
        n_model_shards (int): Number of shards for running large model.
        verbose (bool): Whether to enable verbose logging.
        workdir (str):  Directory for saving logs and checkpoints.
        n_processes (int):  For multi-host, number of processes/nodes.
        host0_address (str):  For multi-host, address of the host0.
        host0_port (int): For multi-host, port of the host0.
        process_id (int): For multi-host, index of the current process.
        n_local_devices (int): For multi-host, number of local devices.
        run_tensorboard (bool):  Whether to enable TensorBoard logging.
        wandb_init_kwargs (dict): wandb.init arguments if using wandb.
    """
    if n_processes is None:
        if 'SLURM_JOB_NUM_NODES' in os.environ:
            n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])
            process_id = int(os.environ['SLURM_NODEID'])
        else:
            n_processes = 1

    if n_processes > 1:
        local_device_ids = None if n_local_devices is None \
            else list(range(n_local_devices))

        if host0_port is None:
            host0_port = DEFAULT_HOST0_PORT

        jax.distributed.initialize(
            coordinator_address=f'{host0_address}:{host0_port}',
            num_processes=n_processes,
            process_id=process_id,
            local_device_ids=local_device_ids)

    if workdir is not None:
        os.makedirs(workdir, exist_ok=True)

    self._verbose = verbose
    self._workdir = workdir
    self._logger = get_logger(verbose=verbose, workdir=workdir)

    if wandb_init_kwargs is not None and jax.process_index() == 0:
        import wandb
        wandb.init(**wandb_init_kwargs)
        self._wandb_log_fn = wandb.log
    else:
        self._wandb_log_fn = None

    if run_tensorboard and jax.process_index() == 0:
        from flax.metrics import tensorboard
        self._summary_writer = tensorboard.SummaryWriter(workdir)
    else:
        self._summary_writer = None

    self.log_info(
        f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')

    self._rng = jax.random.PRNGKey(seed=jax_seed)
    self._mesh = get_mesh(n_model_shards=n_model_shards)
    self._checkpointer = ocp.PyTreeCheckpointer()

gen_model_step_rng()

Get a new random number generator key for distributed model step and update the random state.

Source code in redco/deployers/deployer.py
def gen_model_step_rng(self):
    """Get a new random number generator key for distributed model step and
    update the random state.
    """
    rng = self.gen_rng()
    if self.mesh is None:
        rng = jax.random.split(
            rng, num=jax.process_count())[jax.process_index()]
        rng = shard_prng_key(rng)
    return rng

gen_rng()

Get a new random number generator key and update the random state.

Source code in redco/deployers/deployer.py
def gen_rng(self):
    """Get a new random number generator key and update the random state."""
    self._rng, new_rng = jax.random.split(self._rng)
    return new_rng
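
Example (how the two RNG helpers are typically used together; the comments describe the behavior shown in the source above):

data_rng = deployer.gen_rng()             # plain key, e.g., for shuffling data
step_rng = deployer.gen_model_step_rng()  # key prepared for the model step
# Without a mesh (pure data parallelism), step_rng has already been split per
# process and sharded across local devices, so it can be fed to a pmap-ed step.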

get_accumulate_grad_batches(global_batch_size, per_device_batch_size)

Calculates the number of gradient accumulation batches.

Source code in redco/deployers/deployer.py
def get_accumulate_grad_batches(
        self, global_batch_size, per_device_batch_size):
    """Calculates the number of gradient accumulation batches."""
    _, global_micro_batch_size = self.get_local_global_micro_batch_size(
        per_device_batch_size=per_device_batch_size)
    assert global_batch_size % global_micro_batch_size == 0
    accumulate_grad_batches = global_batch_size // global_micro_batch_size

    return accumulate_grad_batches

get_local_global_micro_batch_size(per_device_batch_size)

Get local/global micro batch sizes based on per-device batch size.

Source code in redco/deployers/deployer.py
def get_local_global_micro_batch_size(self, per_device_batch_size):
    """Get local/global micro batch sizes based on per-device batch size."""
    if self._mesh is None:
        local_micro_batch_size = \
            per_device_batch_size * jax.local_device_count()
        global_micro_batch_size = \
            local_micro_batch_size * jax.process_count()
    else:
        global_micro_batch_size = local_micro_batch_size = \
            per_device_batch_size * self._mesh.shape['dp']

    return local_micro_batch_size, global_micro_batch_size
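
Worked example of the batch-size arithmetic implemented by this method and by get_accumulate_grad_batches (the device and process counts are illustrative):

# With 2 processes x 8 local devices, no model sharding (mesh is None), and
# per_device_batch_size = 4:
#   local_micro_batch_size  = 4 * 8  = 32
#   global_micro_batch_size = 32 * 2 = 64
local_bs, global_bs = deployer.get_local_global_micro_batch_size(
    per_device_batch_size=4)

# For an effective global batch of 256, gradients are accumulated over
# 256 // 64 = 4 micro batches (global_batch_size must divide evenly).
accumulate_grad_batches = deployer.get_accumulate_grad_batches(
    global_batch_size=256, per_device_batch_size=4)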

get_lr_schedule_fn(train_size, per_device_batch_size, n_epochs, learning_rate, schedule_type='linear', warmup_ratio=0.0, warmup_steps=None, init_learning_rate=0.0, end_learning_rate=0.0)

Creates a learning rate schedule function.

Parameters:

Name                   Type   Description                                  Default
---------------------  -----  -------------------------------------------  ---------
train_size             int    Number of training examples per epoch.       required
per_device_batch_size  int    Batch size per device.                       required
n_epochs               int    Number of epochs.                            required
learning_rate          float  Peak learning rate.                          required
schedule_type          str    Type of lr schedule, "linear" or "cosine".   'linear'
warmup_ratio           float  Ratio of lr warmup.                          0.0
warmup_steps           int    Number of warmup steps.                      None
init_learning_rate     float  Initial learning rate before warmup.         0.0
end_learning_rate      float  End learning rate for the schedule.          0.0

Returns:

Type      Description
--------  -----------------------------------------------
Callable  A lr schedule function, step -> learning rate.

Source code in redco/deployers/deployer.py
def get_lr_schedule_fn(self,
                       train_size,
                       per_device_batch_size,
                       n_epochs,
                       learning_rate,
                       schedule_type='linear',
                       warmup_ratio=0.,
                       warmup_steps=None,
                       init_learning_rate=0.,
                       end_learning_rate=0.):
    """Creates a learning rate schedule function.

    Args:
        train_size (int): Number of training examples per epoch.
        per_device_batch_size (int): Batch size per device.
        n_epochs (int): Number of epochs.
        learning_rate (float): Peak learning rate.
        schedule_type (str): Type of lr schedule, "linear" or "cosine".
        warmup_ratio (float): Ratio of lr warmup.
        warmup_steps (int): Number of warmup steps.
        init_learning_rate (float): Initial learning rate before warmup.
        end_learning_rate (float): End learning rate for the schedule.

    Returns:
        (Callable): A lr schedule function, step -> learning rate.
    """
    _, global_micro_batch_size = self.get_local_global_micro_batch_size(
        per_device_batch_size=per_device_batch_size)
    total_train_steps = n_epochs * (train_size // global_micro_batch_size)

    if warmup_steps is None:
        warmup_steps = int(total_train_steps * warmup_ratio)

    return get_lr_schedule_fn(
        schedule_type=schedule_type,
        total_train_steps=total_train_steps,
        warmup_steps=warmup_steps,
        init_learning_rate=init_learning_rate,
        learning_rate=learning_rate,
        end_learning_rate=end_learning_rate)
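
Example (wiring the returned schedule into an optax optimizer; the optimizer choice and the numbers are illustrative):

import optax

lr_schedule_fn = deployer.get_lr_schedule_fn(
    train_size=10000,        # examples per epoch
    per_device_batch_size=4,
    n_epochs=3,
    learning_rate=2e-5,
    schedule_type='linear',
    warmup_ratio=0.1)
optimizer = optax.adamw(learning_rate=lr_schedule_fn)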

get_model_input_batches(examples, per_device_batch_size, collate_fn, shuffle, shuffle_rng, desc, is_train=False, accumulate_grad_batches=None)

Prepares model input batches from examples.

Parameters:

Name                     Type             Description                                Default
-----------------------  ---------------  -----------------------------------------  ---------
examples                 list             List of input examples.                    required
per_device_batch_size    int              Batch size per device.                     required
collate_fn               Callable         Function to collate the examples.          required
shuffle                  bool             Whether to shuffle the examples.           required
shuffle_rng              jax.numpy.Array  RNG for randomness of shuffling.           required
desc                     str              Description in the progress bar.           required
is_train                 bool             Whether the data is for training.          False
accumulate_grad_batches  int              Number of gradient accumulation batches.   None

Returns:

Type       Description
---------  ---------------------------------------------
generator  A Python generator of batched model inputs.

Source code in redco/deployers/deployer.py
def get_model_input_batches(self,
                            examples,
                            per_device_batch_size,
                            collate_fn,
                            shuffle,
                            shuffle_rng,
                            desc,
                            is_train=False,
                            accumulate_grad_batches=None):
    """Prepares model input batches from examples.

    Args:
        examples (list): List of input examples.
        per_device_batch_size (int): Batch size per device.
        collate_fn (Callable): Function to collate the examples.
        shuffle (bool): Whether to shuffle the examples.
        shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.
        desc (str): Description in the progress bar.
        is_train (bool): Whether the data is for training.
        accumulate_grad_batches (int): gradient accumulation batches.

    Returns:
        (generator): A python generator of batched model inputs.
    """
    local_micro_batch_size, global_micro_batch_size = \
        self.get_local_global_micro_batch_size(
            per_device_batch_size=per_device_batch_size)

    examples = get_host_examples(
        examples=examples,
        global_micro_batch_size=global_micro_batch_size,
        shuffle=shuffle,
        shuffle_rng=shuffle_rng,
        mesh=self._mesh)

    if not is_train:
        desc = f'{desc} (global_batch_size = {global_micro_batch_size})'
    elif accumulate_grad_batches is None:
        desc = \
            f'{desc} (global_micro_batch_size = {global_micro_batch_size})'
    else:
        desc = (f'{desc} ('
                f'global_micro_batch_size = {global_micro_batch_size}, '
                f'accumulate_grad_batches = {accumulate_grad_batches})')

    return get_data_batches(
        examples=examples,
        batch_size=local_micro_batch_size,
        collate_fn=collate_fn,
        mesh=self._mesh,
        desc=desc,
        verbose=self._verbose)
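
Example (the example data and the collate function below are illustrative placeholders):

import numpy as np

examples = [{'x': [1., 2.], 'y': 0}, {'x': [3., 4.], 'y': 1}] * 100

def collate_fn(examples):
    # Stack a list of example dicts into a dict of batched arrays.
    return {
        'x': np.array([ex['x'] for ex in examples], dtype=np.float32),
        'y': np.array([ex['y'] for ex in examples], dtype=np.int32),
    }

batches = deployer.get_model_input_batches(
    examples=examples,
    per_device_batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
    shuffle_rng=deployer.gen_rng(),
    desc='Epoch 0',
    is_train=True)
for batch in batches:
    pass  # feed each batch into the training step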

get_opt_state_spec(params_shape_or_params, params_spec, optimizer)

Gets optimizer state specs.

Source code in redco/deployers/deployer.py
def get_opt_state_spec(
        self, params_shape_or_params, params_spec, optimizer):
    """Get optimizer state specs"""
    return get_opt_state_spec(
        params_shape_or_params=params_shape_or_params,
        params_spec=params_spec,
        optimizer=optimizer)

get_params_spec(params_shape_or_params, params_sharding_rules)

Generates parameter specs based on sharding rules.

Source code in redco/deployers/deployer.py
def get_params_spec(self, params_shape_or_params, params_sharding_rules):
    """Generates parameter specs based on sharding rules."""
    return get_params_spec(
        params_shape_or_params=params_shape_or_params,
        params_sharding_rules=params_sharding_rules)

get_sharding_rules(params_shape_or_params)

Get sharding rules based on the parameter shapes.

Source code in redco/deployers/deployer.py
def get_sharding_rules(self, params_shape_or_params):
    """Get sharding rules based on the parameter shapes."""
    if self._mesh is None:
        return None
    else:
        sharding_rules = get_sharding_rules(
            params_shape_or_params=params_shape_or_params,
            n_model_shards=self._mesh.shape['mp'])
        return sharding_rules
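
Example (how the spec helpers compose when a model-parallel mesh is active; the parameter tree and optimizer below are toy placeholders):

import jax.numpy as jnp
import optax

# Toy parameter tree, purely illustrative.
params = {'dense': {'kernel': jnp.zeros((1024, 1024)), 'bias': jnp.zeros((1024,))}}

sharding_rules = deployer.get_sharding_rules(params_shape_or_params=params)
if sharding_rules is not None:  # None means no mesh, i.e., no model sharding
    params_spec = deployer.get_params_spec(
        params_shape_or_params=params,
        params_sharding_rules=sharding_rules)
    opt_state_spec = deployer.get_opt_state_spec(
        params_shape_or_params=params,
        params_spec=params_spec,
        optimizer=optax.adamw(learning_rate=1e-4))
    params = deployer.shard_params(params=params, params_spec=params_spec)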

load_ckpt(ckpt_dir, params_sharding_rules=None, optimizer=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=False)

Loads a checkpoint from the specified directory.

Parameters:

Name                   Type             Description                                      Default
---------------------  ---------------  -----------------------------------------------  --------
ckpt_dir               str              Directory of the checkpoint.                     required
params_sharding_rules  PyTree           Sharding rules for parameters.                   None
optimizer              optax optimizer  Optimizer for loading opt_state.                 None
float_dtype            jax.numpy.dtype  Dtype for floating point numbers.                None
load_params            bool             Whether to load the parameters.                  True
load_opt_state         bool             Whether to load the optimizer state.             True
update_rng             bool             Whether to update the deployer's random state.   False

Returns:

Type   Description
-----  ------------------------------------------------------------------------
tuple  A tuple with the loaded checkpoint (in a dict with "params" and
       "opt_state") and additional information (in a dict, usually including
       "steps", "epoch_idx", and "rng").

Source code in redco/deployers/deployer.py
def load_ckpt(self,
              ckpt_dir,
              params_sharding_rules=None,
              optimizer=None,
              float_dtype=None,
              load_params=True,
              load_opt_state=True,
              update_rng=False):
    """Loads a checkpoint from the specified directory.

    Args:
        ckpt_dir (str): Directory of the checkpoint.
        params_sharding_rules (PyTree): Sharding rules for parameters.
        optimizer (optax optimizer): Optimizer for loading opt_state.
        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.
        load_params (bool): Whether to load the parameters.
        load_opt_state (bool): Whether to load the optimizer state.
        update_rng (bool): if updating the random state of the deployer.

    Returns:
        (tuple): A tuple with the loaded checkpoint (in a dict with
            `"params"` and `"opt_state"`) and additional information (in a
            dict, usually including `"steps"`, `"epoch_idx"`, and `"rng"`).
    """
    ckpt_dir = os.path.abspath(ckpt_dir)
    self.log_info(f'Loading ckpt from {ckpt_dir} ...')

    params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)

    specs = {}
    if self._mesh is not None:
        if params_sharding_rules is None:
            params_sharding_rules = self.get_sharding_rules(
                params_shape_or_params=params_shape)

        specs['params'] = self.get_params_spec(
            params_shape_or_params=params_shape,
            params_sharding_rules=params_sharding_rules)
        if optimizer is not None:
            specs['opt_state'] = self.get_opt_state_spec(
                params_shape_or_params=params_shape,
                params_spec=specs['params'],
                optimizer=optimizer)

    ckpt, info = load_ckpt(
        ckpt_dir=ckpt_dir,
        checkpointer=self._checkpointer,
        params_shape_or_params=params_shape,
        optimizer=optimizer,
        float_dtype=float_dtype,
        mesh=self._mesh,
        specs=specs,
        load_params=load_params,
        load_opt_state=load_opt_state)

    for key, value in info.items():
        if not update_rng and key == 'rng':
            continue
        self.log_info(f'{ckpt_dir}::{key} = {value}')

    if update_rng:
        self._rng = info['rng']
        self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')

    return ckpt, info
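
Example (restoring parameters and optimizer state; the checkpoint path and optimizer are illustrative):

import optax

ckpt, info = deployer.load_ckpt(
    ckpt_dir='./workdir/ckpts/ckpt_1000',       # illustrative checkpoint path
    optimizer=optax.adamw(learning_rate=1e-4),  # pass the optimizer to restore opt_state
    update_rng=False)
params, opt_state = ckpt['params'], ckpt['opt_state']
last_step = info.get('steps')  # info usually also carries 'epoch_idx' and 'rng'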

load_last_ckpt(optimizer=None, params_sharding_rules=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=True)

Loads the last checkpoint from the work directory (self.workdir). See load_ckpt() for the explanation of arguments.

Source code in redco/deployers/deployer.py
def load_last_ckpt(self,
                   optimizer=None,
                   params_sharding_rules=None,
                   float_dtype=None,
                   load_params=True,
                   load_opt_state=True,
                   update_rng=True):
    """Loads the last checkpoint from the work directory (self.workdir).
    See load_ckpt() for the explanation of arguments.
    """
    try:
        last_ckpt_name = open(
            f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()
    except:
        self.log_info(
            f'{self._workdir}/ckpts/last_ckpt.txt not found. '
            f'no ckpt loaded.')
        return None, None

    return self.load_ckpt(
        ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',
        optimizer=optimizer,
        float_dtype=float_dtype,
        params_sharding_rules=params_sharding_rules,
        load_params=load_params,
        load_opt_state=load_opt_state,
        update_rng=update_rng)
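
Example (resuming a run; (None, None) is returned when no ckpts/last_ckpt.txt exists under the workdir):

import optax

optimizer = optax.adamw(learning_rate=1e-4)  # illustrative optimizer
ckpt, info = deployer.load_last_ckpt(optimizer=optimizer)
if ckpt is None:
    last_step = 0  # fresh run, nothing to resume from
else:
    params, opt_state = ckpt['params'], ckpt['opt_state']
    last_step = info['steps']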

load_params_shape(ckpt_dir)

Loads the shape of the parameters from a checkpoint.

Source code in redco/deployers/deployer.py
def load_params_shape(self, ckpt_dir):
    """Loads the shape of the parameters from a checkpoint."""
    return load_params_shape(ckpt_dir=ckpt_dir)

log_info(info, title=None, step=None)

Logs a message.

Source code in redco/deployers/deployer.py
def log_info(self, info, title=None, step=None):
    """Logs a messages"""
    log_info(
        info=info,
        title=title,
        logger=self._logger,
        summary_writer=self._summary_writer,
        step=step)

log_metrics(metrics, step)

Logs metrics to TensorBoard and Weights and Biases (wandb).

Source code in redco/deployers/deployer.py
def log_metrics(self, metrics, step):
    """Logs metrics to TensorBoard and Weights and Biases (wandb)."""
    if self._summary_writer is not None:
        for metric_name, value in metrics.items():
            self._summary_writer.scalar(metric_name, value, step=step)

    if self._wandb_log_fn is not None:
        self._wandb_log_fn(metrics, step)
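
Example (the metric names and values are illustrative):

deployer.log_info('Starting epoch 0', title='Training')
deployer.log_metrics(metrics={'train_loss': 0.42, 'lr': 2e-5}, step=100)
# Metrics reach TensorBoard only if run_tensorboard=True and wandb only if
# wandb_init_kwargs was given; both writers exist only on process 0.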

run_model_step(step_fn, input_args)

Executes a model step function with the provided inputs.

Source code in redco/deployers/deployer.py
def run_model_step(self, step_fn, input_args):
    """Executes a model step function with the provided inputs."""
    if self._mesh is None:
        return step_fn(*input_args)
    else:
        with self._mesh:
            return step_fn(*input_args)

save_ckpt(ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs)

Saves a checkpoint to the specified directory.

Parameters:

Name         Type             Description                                                Default
-----------  ---------------  ---------------------------------------------------------  --------
ckpt_dir     str              Directory to save the checkpoint.                          required
params       dict             Model parameters.                                           required
opt_state    dict             Optimizer state.                                            None
float_dtype  jax.numpy.dtype  Dtype for floating point numbers.                           None
**kwargs     dict             Additional information to be saved into info.json, e.g.,   {}
                              current training step, epoch index, etc.
Source code in redco/deployers/deployer.py
def save_ckpt(
        self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):
    """Saves a checkpoint to the specified directory.

    Args:
        ckpt_dir (str): Directory to save the checkpoint.
        params (dict): Model parameters.
        opt_state (dict): Optimizer state.
        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.
        **kwargs (dict): Additional information to be saved into
            info.json, e.g., current training step, epoch index, etc.
    """
    ckpt_dir = os.path.abspath(ckpt_dir)
    self.log_info(f'Saving ckpt to {ckpt_dir} ...')
    save_ckpt(
        ckpt_dir=ckpt_dir,
        checkpointer=self._checkpointer,
        params=params,
        opt_state=opt_state,
        float_dtype=float_dtype,
        rng=self._rng,
        **kwargs)
    self.log_info(f'Ckpt saved into {ckpt_dir}')
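
Example (the directory layout and the parameter tree are illustrative placeholders; the extra keyword arguments end up in info.json):

import jax.numpy as jnp

params = {'w': jnp.zeros((8, 8))}  # toy parameter tree (placeholder)
deployer.save_ckpt(
    ckpt_dir=f'{deployer.workdir}/ckpts/ckpt_1000',  # illustrative layout
    params=params,
    float_dtype=jnp.bfloat16,  # dtype for floating point numbers in the ckpt
    steps=1000,                # extra info stored alongside the checkpoint
    epoch_idx=0)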

save_outputs(outputs, desc, step)

Saves model outputs to workdir.

Source code in redco/deployers/deployer.py
def save_outputs(self, outputs, desc, step):
    """Saves model outputs to workdir."""
    if self._workdir is not None and jax.process_index() == 0:
        save_outputs(
            workdir=self._workdir,
            outputs=outputs,
            desc=desc,
            step=step,
            logger=self._logger,
            summary_writer=self._summary_writer)

shard_params(params, params_spec, desc='params')

Distributes parameters to all devices based on the provided specs.

Source code in redco/deployers/deployer.py
def shard_params(self, params, params_spec, desc='params'):
    """Distributes parameters to all devices based on the provided specs."""
    self.log_info(info=f'Sharding {desc} ...')
    return shard_params(
        mesh=self._mesh, params=params, params_spec=params_spec)