Drum Transcription¶
Drum transcription module.
Utilities for transcribing drum percussion in music.
Feature Storage Format¶
Processed features are stored in the .hdf file format, one file per piece.
Columns in the file are:
feature: CQT feature.
label: Merged drum label set, with a total of 13 classes.
label_128: Complete drum label set.
mini_beat_arr: The tracked mini-beat array of the clip.
Example¶
>>> import h5py
>>> hdf_ref = h5py.File("ytd_audio_00001_TRBSAIC128E0793CCE.hdf", "r")
>>> hdf_ref.keys()
<KeysViewHDF5 ['mini_beat_arr', 'feature', 'label', 'label_128']>
>>> feature = hdf_ref["feature"][:]
>>> print(feature.shape)
(2299, 120, 120)
>>> hdf_ref.close()
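The remaining columns can be read the same way; their shapes vary per piece. A complementary sketch reusing the file above:
>>> hdf_ref = h5py.File("ytd_audio_00001_TRBSAIC128E0793CCE.hdf", "r")
>>> label_13 = hdf_ref["label"][:]          # merged 13-class label set
>>> label_128 = hdf_ref["label_128"][:]     # complete 128-class label set
>>> mini_beat_arr = hdf_ref["mini_beat_arr"][:]
>>> hdf_ref.close()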
References¶
The related publication can be found in [1].
[1] I-Chieh Wei, Chih-Wei Wu, and Li Su, “Improving Automatic Drum Transcription Using Large-Scale Audio-to-MIDI Aligned Data” (in submission).
App¶
- class omnizart.drum.app.DrumTranscription¶
Bases:
omnizart.base.BaseTranscription
Application class for drum transcriptions.
Methods
generate_feature(dataset_path[, ...]): Extract the feature of the whole dataset.
train(feature_folder[, model_name, ...]): Model training.
transcribe(input_audio[, model_path, output]): Transcribe drums in the audio.
- generate_feature(dataset_path, drum_settings=None, num_threads=3)¶
Extract the feature of the whole dataset.
Currently only the Pop dataset is supported. To train the model, you have to prepare the training data first, then process it into feature representations. After downloading the dataset, use this function to do the pre-processing and transform the raw data into features (see the usage sketch below).
To specify the output path, modify the attribute music_settings.dataset.feature_save_path to the value you want. It defaults to a folder under where the dataset is stored, generating two folders: train_feature and test_feature.
- Parameters
- dataset_path: Path
Path to the downloaded dataset.
- drum_settings: DrumSettings
The configuration instance that holds all relevant settings for the life-cycle of building a model.
- num_threads: int
Number of threads for extracting features in parallel.
See also
omnizart.constants.datasets.PopStructure
The only dataset currently supported for drum transcription. Records the train/test partition according to its folder structure.
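A minimal usage sketch, assuming the Pop dataset has been downloaded to ./Pop (a hypothetical path):
>>> from omnizart.drum.app import DrumTranscription
>>> app = DrumTranscription()
>>> # Pre-process the raw dataset into train_feature/test_feature folders.
>>> app.generate_feature("./Pop", num_threads=4)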
- train(feature_folder, model_name=None, input_model_path=None, drum_settings=None)¶
Model training.
Train a new model or continue training a previously trained model (see the example below).
- Parameters
- feature_folder: Path
Path to the folder containing generated feature.
- model_name: str
The name for storing the trained model. If not given, will default to the current timestamp.
- input_model_path: Path
Continue training a pre-trained model by specifying its path.
- drum_settings: DrumSettings
The configuration instance that holds all relevant settings for the life-cycle of building a model.
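A hedged sketch continuing from feature extraction; the folder name assumes the default train_feature output of generate_feature:
>>> from omnizart.drum.app import DrumTranscription
>>> app = DrumTranscription()
>>> # Train from scratch; pass input_model_path to continue from a checkpoint.
>>> app.train("./Pop/train_feature", model_name="my_drum_model")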
- transcribe(input_audio, model_path=None, output='./')¶
Transcribe drums in the audio.
This function transcribes drum activations in the music. Currently the model predicts 13 classes of drums, and 3 of them are written to the MIDI file (see the example below).
- Parameters
- input_audio: Path
Path to the raw audio file (.wav).
- model_path: Path
Path to the trained model or the supported transcription mode.
- output: Path (optional)
Path for writing out the transcribed MIDI file. Defaults to the current path.
- Returns
- midi: pretty_midi.PrettyMIDI
The transcribed drum notes.
See also
omnizart.cli.drum.transcribe
CLI entry point of this function.
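A minimal sketch; example.wav is a placeholder for any raw audio file:
>>> from omnizart.drum.app import DrumTranscription
>>> app = DrumTranscription()
>>> midi = app.transcribe("example.wav", output="./")  # also writes a MIDI file
>>> midi.instruments  # transcribed drum notes as pretty_midi instruments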
Dataset¶
- class omnizart.drum.app.PopDatasetLoader(mini_beat_per_seg=4, feature_folder=None, feature_files=None, num_samples=100, slice_hop=1)¶
Bases:
omnizart.base.BaseDatasetLoader
Pop dataset loader for training the drum model.
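A construction sketch, assuming feature_folder points to features produced by DrumTranscription.generate_feature:
>>> from omnizart.drum.app import PopDatasetLoader
>>> loader = PopDatasetLoader(mini_beat_per_seg=4, feature_folder="./Pop/train_feature")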
Inference¶
- omnizart.drum.inference.get_3inst_ary(inst_13_ary_in)¶
- omnizart.drum.inference.inference(pred, m_beat_arr, bass_drum_th=0.85, snare_th=1.2, hihat_th=0.17)¶
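The default thresholds mirror the Inference settings listed below (BassDrumTh, SnareTh, HihatTh). A hedged sketch, assuming pred is the model output from omnizart.drum.prediction.predict and m_beat_arr is the tracked mini-beat array of the same piece:
>>> from omnizart.drum.inference import inference
>>> result = inference(pred, m_beat_arr, bass_drum_th=0.85, snare_th=1.2, hihat_th=0.17)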
Labels¶
- omnizart.drum.labels.extract_label(label_path, m_beat_arr)¶
Extract drum label notes.
Process the ground-truth MIDI into a numpy array representation.
- Parameters
- label_path: Path
Path to the midi file.
- m_beat_arr:
Extracted mini-beat array of the corresponding audio piece.
- Returns
- drum_track_ary: numpy.ndarray
The extracted labels as a numpy array, covering a total of 128 classes of drum notes.
See also
omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path
The function for extracting mini-beat array from the given audio path.
- omnizart.drum.labels.extract_label_13_inst(label_path, m_beat_arr)¶
Extract 13 types of drum label notes.
Process the MIDI drum notes into a numpy array and merge them into 13 sub-classes of drum notes.
- Parameters
- label_path: Path
Path to the midi file.
- m_beat_arr:
Extracted mini-beat array of the corresponding audio piece.
- Returns
- drum_track_ary: numpy.ndarray
The extracted labels as a numpy array.
See also
omnizart.drum.labels.extract_label
Complete drum label extraction with 128 output classes.
omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path
The function for extracting mini-beat array from the given audio path.
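A sketch combining the two extractors with the mini-beat tracker referenced above; both file paths are placeholders:
>>> from omnizart.feature.beat_for_drum import extract_mini_beat_from_audio_path
>>> from omnizart.drum.labels import extract_label, extract_label_13_inst
>>> m_beat_arr = extract_mini_beat_from_audio_path("song.wav")
>>> label_128 = extract_label("song.mid", m_beat_arr)           # 128 classes
>>> label_13 = extract_label_13_inst("song.mid", m_beat_arr)    # 13 classes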
Prediction¶
- omnizart.drum.prediction.create_batches(feature, mini_beat_per_seg, b_size=6)¶
Create 4D inputs, batched into a 5D array, for model prediction.
- Parameters
- feature: 3D numpy array
Should be in shape [mini_beat_pos x time x freq].
- mini_beat_per_seg: int
Number of mini beats in one segment (a beat).
- b_size: int
Output batch size.
- Returns
- batch_feature: 5D numpy array
Dimensions are [batches x b_size x time x freq x mini_beat_per_seg].
- pad_size: int
The additional padding size appended to the end of the last batch.
- omnizart.drum.prediction.merge_batches(batch_pred)¶
Reverse process of create_batches.
Merges the 5D batched predictions into a 2D output.
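A shape-oriented sketch of create_batches, reusing the feature dimensions from the storage-format example above; merge_batches reverses the batching on the model's predictions rather than on these features:
>>> import numpy as np
>>> from omnizart.drum.prediction import create_batches
>>> feature = np.zeros((2299, 120, 120))  # [mini_beat_pos x time x freq]
>>> batches, pad_size = create_batches(feature, mini_beat_per_seg=4, b_size=6)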
- omnizart.drum.prediction.predict(patch_cqt_feature, model, mini_beat_per_seg, batch_size=32)¶
Settings¶
Below are the default settings for building the drum model. They are loaded by the class omnizart.setting_loaders.DrumSettings. The names of the attributes are converted to snake-case (e.g., HopSize -> hop_size). There is also a path transformation when applying the settings to the DrumSettings instance: the /Settings level is removed from all fields. For example, the attribute BatchSize defined at the YAML path General/Training/Settings/BatchSize corresponds to DrumSettings.training.batch_size.
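A sketch of the attribute mapping described above, assuming a bare DrumSettings() loads these defaults:
>>> from omnizart.setting_loaders import DrumSettings
>>> settings = DrumSettings()
>>> settings.training.batch_size  # -> 32, from General/Training/Settings/BatchSize
>>> settings.feature.hop_size     # -> 256, from General/Feature/Settings/HopSize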
General:
TranscriptionMode:
Description: Mode of transcription by executing the `omnizart drum transcribe` command.
Type: String
Value: Keras
CheckpointPath:
Description: Path to the pre-trained models.
Type: Map
SubType: [String, String]
Value:
Keras: checkpoints/drum/drum_keras
Feature:
Description: Default settings of feature extraction for drum transcription.
Settings:
SamplingRate:
Description: Adjust input sampling rate to this value.
Type: Integer
Value: 44100
PaddingSeconds:
Description: Padding length at the beginning and the end of the raw audio data.
Type: Float
Value: 1.0
LowestNote:
Description: Lowest MIDI note number to be considered.
Type: Integer
Value: 16
NumberOfNotes:
Description: Number of total notes to extract.
Type: Integer
Value: 120
HopSize:
Description: Hop size for computing CQT feature.
Type: Integer
Value: 256
MiniBeatPerBar:
Description: Number of mini beats in a single 4/4 measure.
Type: Integer
Value: 32
MiniBeatPerSegment:
Description: Number of mini beats in a single segment.
Type: Integer
Value: 4
Dataset:
Description: Settings of datasets.
Settings:
SavePath:
Description: Path for storing the downloaded datasets.
Type: String
Value: ./
FeatureSavePath:
Description: Path for storing the extracted features. Defaults to a path under the dataset folder.
Type: String
Value: +
Model:
Description: Default settings of training / testing the model.
Settings:
SavePrefix:
Description: Prefix of the trained model's name to be saved.
Type: String
Value: drum
SavePath:
Description: Path to save the trained model.
Type: String
Value: ./checkpoints/drum
Inference:
Description: Default settings when inferring notes.
Settings:
BassDrumTh:
Description: Threshold for the bass drum.
Type: Float
Value: 0.85
SnareTh:
Description: Threshold for the snare.
Type: Float
Value: 1.2
HihatTh:
Description: Threshold for the hihat.
Type: Float
Value: 0.17
Training:
Description: Hyperparameters for training.
Settings:
Epoch:
Description: Maximum number of epochs for training.
Type: Integer
Value: 50
Steps:
Description: Number of training steps for each epoch.
Type: Integer
Value: 1000
ValSteps:
Description: Number of validation steps after each training epoch.
Type: Integer
Value: 100
BatchSize:
Description: Batch size of each training step.
Type: Integer
Value: 32
ValBatchSize:
Description: Batch size of each validation step.
Type: Integer
Value: 32
EarlyStop:
Description: Terminate the training if the validation performance doesn't improve after n epochs.
Type: Integer
Value: 6
InitLearningRate:
Description: Initial learning rate.
Type: Float
Value: 0.00002
ResBlockNum:
Description: Number of residual blocks.
Type: Integer
Value: 3