Training¶
Including logics of main training loop, progress visualization, and callback functions.
- omnizart.train.PROGRESS_BAR_FORMAT = Format of the training progress bar¶
- omnizart.train.execute_callbacks(callbacks, func_name, **kwargs)¶
Execute the given callback function on all callbacks at different training stages.
- omnizart.train.format_num(num, digit=4)¶
Format a float value as a string.
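For illustration, a minimal sketch of such a formatter (an assumption about the behaviour, not omnizart's actual code), where `digit` is taken to mean the number of significant digits shown, keeping the progress bar narrow for both tiny and large values:

```python
def format_num(num, digit=4):
    """Format a float as a short string.

    Sketch only: assumes `digit` controls significant digits; the real
    omnizart implementation may format differently.
    """
    return f"{num:.{digit}g}"

print(format_num(0.123456789))         # prints "0.1235"
print(format_num(1234.5678, digit=6))  # prints "1234.57"
```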
- omnizart.train.gen_bar_postfix(history, targets=['loss', 'accuracy'], name_transform=['loss', 'acc'])¶
Generate string of metrics status to be appended to the end of the progress bar.
- Parameters
- history: dict
History records generated by train_steps.
- targets: list[str]
List of metric names to be extracted as the postfix.
- name_transform: list[str]
The alias metric names that will be shown on the bar. Should have the same length and order as targets.
- Returns
- postfix: str
The extracted metrics information.
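As a hedged sketch of how such a postfix could be assembled from a history dict (the exact formatting and separator are assumptions, not omnizart's actual code):

```python
def gen_bar_postfix(history, targets=("loss", "accuracy"), name_transform=("loss", "acc")):
    # Sketch: take the latest value of each target metric from the history
    # and render it under its alias name. Formatting is an assumption.
    parts = []
    for target, alias in zip(targets, name_transform):
        if target not in history or not history[target]:
            continue
        parts.append(f"{alias}: {history[target][-1]:.4f}")
    return ", ".join(parts)

history = {"loss": [0.5321, 0.4123], "accuracy": [0.71, 0.82]}
print(gen_bar_postfix(history))  # prints "loss: 0.4123, acc: 0.8200"
```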
- omnizart.train.get_train_val_feat_file_list(feature_folder, split=0.9)¶
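The docstring above is empty; judging only from the signature, the function plausibly splits the feature files in a folder into training and validation lists by the given ratio. A self-contained sketch of that behaviour (the file extension, shuffling, and return shape are all assumptions):

```python
import random
import tempfile
from pathlib import Path

def get_train_val_feat_file_list(feature_folder, split=0.9):
    # Sketch: collect feature files (the *.hdf extension is an assumption),
    # shuffle, then split into train/validation lists by the given ratio.
    feat_files = sorted(Path(feature_folder).glob("*.hdf"))
    random.shuffle(feat_files)
    split_idx = round(len(feat_files) * split)
    return feat_files[:split_idx], feat_files[split_idx:]

# Demo on a throw-away folder containing ten fake feature files.
with tempfile.TemporaryDirectory() as folder:
    for idx in range(10):
        (Path(folder) / f"song_{idx}.hdf").touch()
    train_files, val_files = get_train_val_feat_file_list(folder)
    print(len(train_files), len(val_files))  # prints "9 1"
```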
- omnizart.train.train_epochs(model, train_dataset, validate_dataset=None, epochs=10, steps=100, val_steps=100, callbacks=None, **kwargs)¶
Logic of training loop.
The main loop of the training, with events-based life-cycle management that triggers different events for all callbacks. Event types are the same as the original tensorflow implementation.
Event types and their order:
<start training>
- on_train_begin
- on_epoch_begin (for each epoch)
- on_train_batch_begin / on_train_batch_end (for each training batch)
- on_test_batch_begin / on_test_batch_end (for each validation batch)
- on_epoch_end (for each epoch)
- on_train_end
<finished training>
- Parameters
- model:
Compiled tensorflow keras model.
- train_dataset:
The tf.data.Dataset instance for training.
- validate_dataset:
The tf.data.Dataset instance for validation. If not given, validation stage will be skipped.
- epochs: int
Number of maximum training epochs.
- steps: int
Number of training steps for each epoch. Should be the same as when initiating the dataset instance.
- val_steps: int
Number of validation steps for each epoch. Should be the same as when initiating the dataset instance.
- callbacks:
List of callback instances.
- Returns
- history: dict
Score history of each metric during each epoch, for both training and validation.
See also
omnizart.callbacks
Implementation and available callbacks for training.
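The event life-cycle above can be sketched with a plain-Python loop that dispatches each event to every callback, in the same spirit as execute_callbacks (a simplified illustration; the real loop also runs the actual training and validation steps):

```python
def run_training(callbacks, epochs=2, train_batches=2, test_batches=1):
    # Simplified event dispatcher mirroring the documented event order.
    def fire(event):
        for cb in callbacks:
            handler = getattr(cb, event, None)
            if handler is not None:
                handler()

    fire("on_train_begin")
    for _ in range(epochs):
        fire("on_epoch_begin")
        for _ in range(train_batches):
            fire("on_train_batch_begin")
            # ... forward pass, loss, and weight update would happen here ...
            fire("on_train_batch_end")
        for _ in range(test_batches):
            fire("on_test_batch_begin")
            fire("on_test_batch_end")
        fire("on_epoch_end")
    fire("on_train_end")

class EventRecorder:
    # Toy callback that just records which events were fired, in order.
    def __init__(self):
        self.events = []
    def __getattr__(self, name):
        if name.startswith("on_"):
            return lambda: self.events.append(name)
        raise AttributeError(name)

recorder = EventRecorder()
run_training([recorder], epochs=1, train_batches=1, test_batches=1)
print(recorder.events)
```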
- omnizart.train.train_steps(model, dataset, steps=None, bar_title=None, validate=False)¶
A single training epoch with multiple steps.
Customized training epoch, compared to the built-in .fit(...) function of a tensorflow keras model. The major difference is that .fit() requires the dataset to yield either (feature, target) or (feature, target, weight) pairs, which loses the flexibility of yielding a different number of elements on each iteration. We therefore implement our own training logic and relevant utilities (e.g. callbacks) analogous to those provided by tensorflow.
- Parameters
- model:
Compiled tf.keras model.
- dataset:
The loaded tf.data.Dataset object that yields (feature, target) pairs as the first two elements. You can yield more than two elements per iteration, but only the first two will be used for training.
- steps: int
Total number of steps that the dataset object will yield. This is used for visualizing the training progress.
- bar_title: str
Additional title to be printed at the start of the progress bar.
- validate: bool
Whether this is the validation stage, or a training epoch that should update the weights of the model.
- Returns
- history: dict
The history of scores for each metric during each epoch.
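The key point above, that only the first two yielded elements are consumed, can be illustrated with a framework-free sketch of the step loop (the dummy step function and the history layout are assumptions, not omnizart's actual code):

```python
def train_steps(step_fn, dataset, steps):
    # Sketch of the step loop: run a bounded number of steps and feed only
    # the first two elements of each yielded tuple to the model.
    history = {"loss": []}
    for idx, batch in enumerate(dataset):
        if idx >= steps:
            break
        feature, target = batch[0], batch[1]  # extra elements are ignored
        history["loss"].append(step_fn(feature, target))
    return history

def dummy_step(feature, target):
    # Stand-in for one gradient-update step; returns a fake "loss".
    return abs(feature - target)

# The dataset may yield more than two elements per iteration; the third
# element here (a sample id) is simply ignored by the loop.
dataset = [(1.0, 0.5, "sample-a"), (0.8, 0.5, "sample-b"), (0.6, 0.5, "sample-c")]
history = train_steps(dummy_step, dataset, steps=2)
print(history["loss"])
```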
Callbacks¶
- class omnizart.callbacks.Callback(monitor=None)¶
Base class of all callback classes.
Methods
on_epoch_begin
on_epoch_end
on_test_batch_begin
on_test_batch_end
on_train_batch_begin
on_train_batch_end
on_train_begin
on_train_end
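A custom callback only needs to override the hooks it cares about. A hedged sketch of subclassing (the base class below is a local stand-in with the same hook names; hook signatures are assumptions, and the real base is omnizart.callbacks.Callback):

```python
class Callback:
    # Stand-in base mirroring the documented hook names; each hook is a no-op.
    def __init__(self, monitor=None):
        self.monitor = monitor
    def on_train_begin(self, history=None): pass
    def on_train_end(self, history=None): pass
    def on_epoch_begin(self, history=None): pass
    def on_epoch_end(self, history=None): pass
    def on_train_batch_begin(self, history=None): pass
    def on_train_batch_end(self, history=None): pass
    def on_test_batch_begin(self, history=None): pass
    def on_test_batch_end(self, history=None): pass

class EpochLogger(Callback):
    # Example subclass: count epochs and remember the last seen metrics.
    def __init__(self):
        super().__init__()
        self.epochs_seen = 0
        self.last_history = None
    def on_epoch_end(self, history=None):
        self.epochs_seen += 1
        self.last_history = history

logger = EpochLogger()
for epoch in range(3):
    logger.on_epoch_end({"loss": 1.0 / (epoch + 1)})
print(logger.epochs_seen)  # prints "3"
```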
- class omnizart.callbacks.EarlyStopping(patience=5, monitor='val_acc')¶
Stop the training early when the monitored metric shows no improvement for a certain period.
- Parameters
- patience: int
Longest number of epochs to wait for the target metric to show improvement.
- monitor: str
Metric name for the observation.
Methods
on_epoch_end
on_train_begin
on_train_end
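The patience mechanism can be sketched without tensorflow (a simplified illustration of the documented behaviour, not omnizart's actual code):

```python
class EarlyStoppingSketch:
    # Sketch of patience-based early stopping: stop once the monitored
    # metric has failed to improve for `patience` consecutive epochs.
    def __init__(self, patience=5, monitor="val_acc"):
        self.patience = patience
        self.monitor = monitor
        self.best = float("-inf")
        self.wait = 0
        self.stopped = False

    def on_epoch_end(self, history):
        score = history[self.monitor][-1]
        if score > self.best:
            self.best = score
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped = True

stopper = EarlyStoppingSketch(patience=2)
history = {"val_acc": []}
for score in [0.5, 0.6, 0.59, 0.58, 0.57]:
    history["val_acc"].append(score)
    stopper.on_epoch_end(history)
    if stopper.stopped:
        break  # training would end here, after two epochs without improvement
print(stopper.stopped, stopper.best)  # prints "True 0.6"
```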
- class omnizart.callbacks.ModelCheckpoint(filepath, monitor='val_acc', save_best_only=False, save_weights_only=False)¶
Saving the model during training.
The newest checkpoint will overwrite the previous checkpoint within a single training period.
- Parameters
- filepath: Path
Path for saving the checkpoint.
- monitor: str
Metric name for the observation. No effect if save_best_only is set to False.
- save_best_only: bool
Whether to save the model having the best performance on the metric only.
- save_weights_only: bool
Save only the model’s weights, without the architecture.
Methods
on_epoch_end
on_train_begin
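The save_best_only behaviour can be sketched as follows (a simplified mock; "saving" here records the score instead of serializing a real model):

```python
class CheckpointSketch:
    # Sketch of save_best_only: save every epoch, or only when the monitored
    # metric improves. Saving is mocked as recording the score.
    def __init__(self, monitor="val_acc", save_best_only=False):
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.best = float("-inf")
        self.saves = []

    def on_epoch_end(self, history):
        score = history[self.monitor][-1]
        if self.save_best_only and score <= self.best:
            return  # no improvement: keep the previous checkpoint
        self.best = max(self.best, score)
        self.saves.append(score)

ckpt = CheckpointSketch(save_best_only=True)
history = {"val_acc": []}
for score in [0.5, 0.7, 0.6, 0.8]:
    history["val_acc"].append(score)
    ckpt.on_epoch_end(history)
print(ckpt.saves)  # prints "[0.5, 0.7, 0.8]"
```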
- class omnizart.callbacks.TFModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', save_freq='epoch', options=None, **kwargs)¶
Re-implementation of Tensorflow ModelCheckpoint.
Customizes the checkpoint-saving behaviour. When save_weights_only is set to True, only the weights are saved during training, and the whole model, including the architecture, is saved with model.save() at the end of training.
This callback is mainly designed for saving customized models that cannot use the model.to_yaml() function.
Methods
on_train_end
(logs) Called at the end of training.
set_model
- on_train_end(logs)¶
Called at the end of training.
Subclasses should override for any actions to run.
- Args:
- logs: Dict. Currently the output of the last call to on_epoch_end() is passed to this argument for this method, but that may change in the future.
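The documented behaviour, weights-only saves during training plus a full-model save at the end, can be mocked as (a sketch; the real implementation wraps tensorflow's ModelCheckpoint and writes actual files):

```python
class TFCheckpointSketch:
    # Sketch: during training save weights only; at on_train_end save the
    # whole model (mocked here as string records instead of real files).
    def __init__(self, save_weights_only=True):
        self.save_weights_only = save_weights_only
        self.saved = []

    def on_epoch_end(self, epoch):
        if self.save_weights_only:
            self.saved.append(f"weights-epoch-{epoch}")
        else:
            self.saved.append(f"model-epoch-{epoch}")

    def on_train_end(self, logs=None):
        # Always persist the full model (architecture + weights) at the end,
        # mirroring the model.save() call described above.
        self.saved.append("full-model")

ckpt = TFCheckpointSketch()
for epoch in range(2):
    ckpt.on_epoch_end(epoch)
ckpt.on_train_end()
print(ckpt.saved)  # prints "['weights-epoch-0', 'weights-epoch-1', 'full-model']"
```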