Beat Transcription¶
MIDI-domain beat tracking.
Tracks beats and down beats in the symbolic domain and outputs the predicted beat positions in seconds. Re-implementation of the work [1] with TensorFlow 2.3.0.
Feature Storage Format¶
Processed features will be stored in .hdf format, one file per piece.
Columns in the file are:
feature: Piano-roll-like representation with mixed information.
label: Corresponding beat and down beat labels.
References¶
App¶
- class omnizart.beat.app.BeatTranscription(conf_path=None)¶
Bases:
omnizart.base.BaseTranscription
Application class for beat tracking in MIDI domain.
Methods
- generate_feature(dataset_path[, ...]): Extract the feature from the given dataset.
- train(feature_folder[, model_name, ...]): Model training.
- transcribe(input_audio[, model_path, output]): Transcribe beat positions in the given MIDI.
- generate_feature(dataset_path, beat_settings=None, num_threads=8)¶
Extract the feature from the given dataset.
To train the model, the first step is to pre-process the data into feature representations. After downloading the dataset, use this function to generate the feature by giving the path of the stored dataset.
To specify the output path, modify the attribute beat_settings.dataset.feature_save_path. It defaults to the folder where the dataset is stored, under which two folders will be generated: train_feature and test_feature.
- Parameters
- dataset_path: Path
Path to the downloaded dataset.
- beat_settings: BeatSettings
The configuration instance that holds all relevant settings for the life-cycle of building a model.
- num_threads:
Number of threads for parallel feature extraction.
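A minimal usage sketch; the dataset path below is a placeholder:

```python
from omnizart.beat.app import BeatTranscription

app = BeatTranscription()
# Extract training/testing features from a downloaded dataset.
# "./MusicNet" is a hypothetical path to the downloaded dataset.
app.generate_feature("./MusicNet", num_threads=4)
```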
- train(feature_folder, model_name=None, input_model_path=None, beat_settings=None)¶
Model training.
Train the model from scratch or continue training given a model checkpoint.
- Parameters
- feature_folder: Path
Path to the generated feature.
- model_name: str
The name of the trained model. If not given, defaults to the current timestamp.
- input_model_path: Path
Specify the path to the model checkpoint in order to fine-tune the model.
- beat_settings: BeatSettings
The configuration that holds all relevant settings for the life-cycle of model building.
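A minimal training sketch, assuming features were already extracted into the folder below; the folder and model name are placeholders:

```python
from omnizart.beat.app import BeatTranscription

app = BeatTranscription()
# Train from scratch on the extracted features.
# Pass input_model_path to fine-tune an existing checkpoint instead.
app.train("./MusicNet/train_feature", model_name="my_beat_model")
```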
- transcribe(input_audio, model_path=None, output='./')¶
Transcribe beat positions in the given MIDI.
Tracks the beats in the symbolic domain. Outputs three files if the output path is given: <filename>.mid, <filename>_beat.csv, and <filename>_down_beat.csv, where filename is the name of the input MIDI without extension. The *.csv files record the beat positions in seconds.
- Parameters
- input_audio: Path
Path to the MIDI file (.mid).
- model_path: Path
Path to the trained model or the supported transcription mode.
- output: Path (optional)
Path for writing out the transcribed MIDI file. Defaults to the current directory.
- Returns
- midi: pretty_midi.PrettyMIDI
The transcribed beat positions. There are two types of beat: beat and down beat. Each is recorded in an independent instrument track.
See also
omnizart.cli.beat.transcribe
CLI entry point of this function.
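A minimal transcription sketch; the input MIDI path and output folder are placeholders, and omitting model_path is assumed to fall back to the default pre-trained model:

```python
from omnizart.beat.app import BeatTranscription

app = BeatTranscription()
# Transcribe beat positions of a MIDI file and write the results to ./results.
midi = app.transcribe("./example.mid", output="./results")
```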
Dataset¶
- class omnizart.beat.app.BeatDatasetLoader(feature_folder=None, feature_files=None, num_samples=100, slice_hop=1, feat_col_name='feature')¶
Bases:
omnizart.base.BaseDatasetLoader
Data loader for training the beat model.
Each feature slice will have an overlap size of timesteps//2.
- Parameters
- feature_folder: Path
Path to the extracted feature files, including *.hdf and *.pickle pairs, which refer to feature and label files, respectively.
- feature_files: list[Path]
List of paths to *.hdf feature files. The corresponding label files should be under the same folder.
- num_samples: int
Total number of samples to yield.
- timesteps: int
Time length of the feature.
- Yields
- feature:
Input features for model training.
- label:
Corresponding labels.
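A sketch of constructing the loader from previously extracted features; the feature folder is a placeholder:

```python
from omnizart.beat.app import BeatDatasetLoader

# Yields (feature, label) pairs sliced from the extracted *.hdf/*.pickle files.
loader = BeatDatasetLoader(
    feature_folder="./MusicNet/train_feature",
    num_samples=100
)
```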
Inference¶
- omnizart.beat.inference.inference(pred, beat_th=0.5, down_beat_th=0.5, min_dist=0.3, t_unit=0.1)¶
Infers the beat and down beat positions from the raw prediction values.
- Parameters
- pred: 2D numpy array
The prediction of the model.
- beat_th: float
Threshold for beat channel.
- down_beat_th: float
Threshold for down beat channel.
- min_dist: float
Minimum distance between two beat positions in seconds.
- t_unit: float
Time unit of each frame in seconds.
- Returns
- midi: pretty_midi.PrettyMIDI
Inferred beat positions recorded as MIDI notes. Information of beat and down beat are recorded in two different instrument tracks.
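A minimal sketch of turning raw predictions into a beat MIDI; the frames-by-channels layout of pred and the channel order are assumptions, and random values stand in for real model output:

```python
import numpy as np
from omnizart.beat.inference import inference

# Fake prediction: 500 frames; channel 0 = beat, channel 1 = down beat (layout assumed).
pred = np.random.rand(500, 2)
midi = inference(pred, beat_th=0.5, down_beat_th=0.3, min_dist=0.3, t_unit=0.1)
midi.write("beats.mid")
```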
Loss Functions¶
- omnizart.beat.app.weighted_binary_crossentropy(target, pred, down_beat_weight=5)¶
Wraps binary cross-entropy loss with per-channel weighting.
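A sketch of calling the loss directly; the (batch, time, channel) tensor shape and the beat/down-beat channel order are assumptions:

```python
import tensorflow as tf
from omnizart.beat.app import weighted_binary_crossentropy

# Toy tensors shaped (batch, time, channel); channel 0 = beat, channel 1 = down beat (assumed).
target = tf.constant([[[1.0, 0.0], [0.0, 1.0]]])
pred = tf.constant([[[0.8, 0.1], [0.2, 0.7]]])
loss = weighted_binary_crossentropy(target, pred, down_beat_weight=5)
```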
Features¶
- omnizart.beat.features.extract_feature(labels, t_unit=0.01)¶
Extract feature representation required by beat module.
- Parameters
- labels: list[Label]
List of omnizart.base.Label instances.
- t_unit: float
Time unit of each frame of the output representation.
- Returns
- feature: 2D numpy array
A piano-roll-like representation. Please refer to the original paper for more details.
- omnizart.beat.features.extract_feature_from_midi(midi_path, t_unit=0.01)¶
Extract feature for beat module from MIDI file.
See also
omnizart.beat.features.extract_feature
The main feature extraction function of beat module.
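A sketch of extracting the input feature directly from a MIDI file; the path is a placeholder:

```python
from omnizart.beat.features import extract_feature_from_midi

# Build the piano-roll-like feature representation from a MIDI file.
feature = extract_feature_from_midi("./example.mid", t_unit=0.01)
print(feature.shape)  # (frames, feature_dim)
```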
- omnizart.beat.features.extract_musicnet_feature(csv_path, t_unit=0.01)¶
Extract feature for beat module from MusicNet label file.
See also
omnizart.beat.features.extract_feature
The main feature extraction function of beat module.
- omnizart.beat.features.extract_musicnet_label(csv_path, meter=4, t_unit=0.01, rounding=1, fade_out=15)¶
Label extraction function for MusicNet.
This function extracts the beat and down beat information given the symbolic representations of MusicNet.
- Parameters
- csv_path: Path
Path to the ground-truth file in CSV format.
- meter: int
Meter of the piece. Defaults to the most common meter, which is 4. Since MusicNet does not record meter information, the value will always be 4, which is obviously not always correct.
- t_unit: float
Time unit of each frame in seconds.
- rounding: int
Number of decimal places to which the starting beat position is rounded.
- fade_out: int
Augments the sparse positive labels in a fade-out manner, gradually reducing the value from 1 to 1/fade_out over a span of <fade_out> frames.
Prediction¶
- omnizart.beat.prediction.STEP_SIZE_RATIO = 0.5¶
Step size for slicing the feature, expressed as a ratio of the model's input timesteps.
- omnizart.beat.prediction.create_batches(feature, timesteps, batch_size=8)¶
Create a 4D output from the 2D feature for model prediction.
Create overlapped input features and collect the feature slices into batches. The overlap size is one quarter of the timesteps.
- Parameters
- feature: 2D numpy array
The feature representation for the model.
- timesteps: int
Size of the input feature dimension.
- batch_size: int
Batch size.
- Returns
- batches: 4D numpy array
Batched feature slices with dimension: batches x batch_size x timesteps x feat.
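A sketch of batching a 2D feature for prediction; the number of frames and the feature dimension below are arbitrary stand-ins:

```python
import numpy as np
from omnizart.beat.prediction import create_batches

# Synthetic feature: 2500 frames with an arbitrary feature dimension.
feature = np.random.rand(2500, 178)
batches = create_batches(feature, timesteps=1000, batch_size=8)
print(batches.shape)  # batches x batch_size x timesteps x feat
```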
- omnizart.beat.prediction.merge_batches(batch_pred)¶
Merge the batched predictions back to the 2D output.
- omnizart.beat.prediction.predict(feature, model, timesteps=1000, batch_size=64)¶
Predict on the given feature with the model.
- Parameters
- feature: 2D numpy array
Input feature of the model.
- model:
The pre-trained Tensorflow model.
- timesteps: int
Size of the input feature dimension.
- batch_size: int
Batch size for the model input.
- Returns
- pred: 2D numpy array
The predicted probabilities of beat and down beat positions.
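An end-to-end prediction sketch tying the pieces together; the MIDI path is a placeholder, and loading the checkpoint with tf.keras.models.load_model is an assumption (the path matches the default checkpoint in the settings below):

```python
import tensorflow as tf
from omnizart.beat.features import extract_feature_from_midi
from omnizart.beat.prediction import predict
from omnizart.beat.inference import inference

# Loading the checkpoint this way is an assumption about the stored model format.
model = tf.keras.models.load_model("checkpoints/beat/beat_blstm")

feature = extract_feature_from_midi("./example.mid", t_unit=0.01)
pred = predict(feature, model, timesteps=1000, batch_size=64)
midi = inference(pred, beat_th=0.5, down_beat_th=0.3)
midi.write("beats.mid")
```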
Settings¶
Below are the default settings for building the beat model. They will be loaded
by the class omnizart.setting_loaders.BeatSettings. The attribute names are
converted to snake-case (e.g., HopSize -> hop_size). There is also a path
transformation when the settings are applied to the BeatSettings instance. For
example, to access the attribute BatchSize defined at the YAML path
General/Training/Settings/BatchSize, the corresponding attribute is
BeatSettings.training.batch_size. The /Settings level is removed from all fields.
General:
    TranscriptionMode:
        Description: Mode of transcription by executing the `omnizart beat transcribe` command.
        Type: String
        Value: BLSTM
    CheckpointPath:
        Description: Path to the pre-trained models.
        Type: Map
        SubType: [String, String]
        Value:
            BLSTM: checkpoints/beat/beat_blstm
    Feature:
        Description: Default settings of feature extraction for beat transcription.
        Settings:
            TimeUnit:
                Description: Time unit of each frame in seconds.
                Type: Float
                Value: 0.01
    Dataset:
        Description: Settings of datasets.
        Settings:
            SavePath:
                Description: Path for storing the downloaded datasets.
                Type: String
                Value: ./
            FeatureSavePath:
                Description: Path for storing the extracted feature. Defaults to the path under the dataset folder.
                Type: String
                Value: +
    Model:
        Description: Default settings of training / testing the model.
        Settings:
            SavePrefix:
                Description: Prefix of the trained model's name to be saved.
                Type: String
                Value: beat
            SavePath:
                Description: Path to save the trained model.
                Type: String
                Value: ./checkpoints/beat
            ModelType:
                Description: One of 'blstm' or 'blstm_attn'.
                Type: String
                Value: blstm
            Timesteps:
                Description: Input length of the model.
                Type: Integer
                Value: 1000
            LstmHiddenDim:
                Description: Dimension of LSTM hidden layers.
                Type: Integer
                Value: 25
            NumLstmLayers:
                Description: Number of LSTM layers.
                Type: Integer
                Value: 2
            AttnHiddenDim:
                Description: Dimension of multi-head attention layers.
                Type: Integer
                Value: 256
    Inference:
        Description: Default settings when inferring beat positions.
        Settings:
            BeatThreshold:
                Description: Threshold that will be applied to clip the predicted beat values to either 0 or 1.
                Type: Float
                Value: 0.5
            DownBeatThreshold:
                Description: Same as above, but for down beat.
                Type: Float
                Value: 0.3
            MinDistance:
                Description: Minimum required distance between two beat positions in seconds.
                Type: Float
                Value: 0.3
    Training:
        Description: Hyperparameters for training.
        Settings:
            Epoch:
                Description: Maximum number of epochs for training.
                Type: Integer
                Value: 10
            Steps:
                Description: Number of training steps for each epoch.
                Type: Integer
                Value: 1000
            ValSteps:
                Description: Number of validation steps after each training epoch.
                Type: Integer
                Value: 50
            BatchSize:
                Description: Batch size of each training step.
                Type: Integer
                Value: 64
            ValBatchSize:
                Description: Batch size of each validation step.
                Type: Integer
                Value: 64
            EarlyStop:
                Description: Terminate the training if the validation performance doesn't improve after n epochs.
                Type: Integer
                Value: 7
            InitLearningRate:
                Description: Initial learning rate.
                Type: Float
                Value: 0.001
            DownBeatWeight:
                Description: Weighting of the down beat loss. The beat loss is always set to one.
                Type: Float
                Value: 5
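A sketch of reading the converted attributes; training.batch_size mirrors the example in the text above, while model.timesteps follows the same conversion rule and is an assumption:

```python
from omnizart.setting_loaders import BeatSettings

settings = BeatSettings()
# General/Training/Settings/BatchSize -> training.batch_size (example from the text).
print(settings.training.batch_size)
# General/Model/Settings/Timesteps -> model.timesteps (same rule, assumed).
print(settings.model.timesteps)
```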