Drum Transcription

Drum transcription module.

Utilities for transcribing drum percussion in music.

Feature Storage Format

Processed features will be stored in the .hdf file format, one file per piece.

Keys in the file are:

  • feature: CQT feature.

  • label: Merged drum label set, with a total of 13 classes.

  • label_128: Complete drum label set.

  • mini_beat_arr: The tracked mini-beat array of the clip.

Example

>>> import h5py
>>> hdf_ref = h5py.File("ytd_audio_00001_TRBSAIC128E0793CCE.hdf", "r")
>>> hdf_ref.keys()
<KeysViewHDF5 ['feature', 'label', 'label_128', 'mini_beat_arr']>
>>> feature = hdf_ref["feature"][:]
>>> print(feature.shape)
(2299, 120, 120)
>>> hdf_ref.close()

References

The related publication can be found in [1].

[1] I-Chieh Wei, Chih-Wei Wu, and Li Su, “Improving Automatic Drum Transcription Using Large-Scale Audio-to-MIDI Aligned Data” (in submission)

App

class omnizart.drum.app.DrumTranscription

Bases: omnizart.base.BaseTranscription

Application class for drum transcription.

Methods

  • generate_feature(dataset_path[, ...]): Extract the feature of the whole dataset.

  • train(feature_folder[, model_name, ...]): Model training.

  • transcribe(input_audio[, model_path, output]): Transcribe drums in the audio.

generate_feature(dataset_path, drum_settings=None, num_threads=3)

Extract the feature of the whole dataset.

Currently only supports the Pop dataset. To train the model, you have to prepare the training data first and then process it into feature representations. After downloading the dataset, use this function to do the pre-processing and transform the raw data into features.

To specify the output path, modify the attribute music_settings.dataset.feature_save_path to the value you want. It defaults to the folder where the dataset is stored, under which two folders will be generated: train_feature and test_feature.

Parameters
dataset_path: Path

Path to the downloaded dataset.

drum_settings: DrumSettings

The configuration instance that holds all related settings for the life-cycle of building a model.

num_threads:

Number of threads for extracting the features in parallel.

See also

omnizart.constants.datasets.PopStructure

The only supported dataset for drum transcription. Records the train/test partition according to the folder structure.
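
A minimal usage sketch, assuming the Pop dataset has been downloaded to the hypothetical path ./Pop:

import omnizart  # noqa: ensure the package is installed
from omnizart.drum.app import DrumTranscription

app = DrumTranscription()
# Generates train_feature/ and test_feature/ under the dataset folder.
app.generate_feature("./Pop", num_threads=4)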

train(feature_folder, model_name=None, input_model_path=None, drum_settings=None)

Model training.

Train a new model or continue to train on a previously trained model.

Parameters
feature_folder: Path

Path to the folder containing generated feature.

model_name: str

The name for storing the trained model. If not given, will default to the current timestamp.

input_model_path: Path

Continue to train on a pre-trained model by specifying its path.

drum_settings: DrumSettings

The configuration instance that holds all related settings for the life-cycle of building a model.
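
A minimal training sketch, assuming features were generated into the hypothetical folder ./Pop/train_feature:

from omnizart.drum.app import DrumTranscription

app = DrumTranscription()
# Train a new model; pass input_model_path to continue from a previous checkpoint.
app.train("./Pop/train_feature", model_name="my-drum-model")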

transcribe(input_audio, model_path=None, output='./')

Transcribe drums in the audio.

This function transcribes drum activations in the music. Currently the model predicts 13 classes of drum instruments, and 3 of them (bass drum, snare, and hihat; see the Inference section below) will be written to the MIDI file.

Parameters
input_audio: Path

Path to the raw audio file (.wav).

model_path: Path

Path to the trained model, or the name of a supported transcription mode.

output: Path (optional)

Path for writing out the transcribed MIDI file. Defaults to the current path.

Returns
midi: pretty_midi.PrettyMIDI

The transcribed drum notes.

See also

omnizart.cli.drum.transcribe

CLI entry point of this function.
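
A minimal transcription sketch; song.wav is a hypothetical input file:

from omnizart.drum.app import DrumTranscription

app = DrumTranscription()
midi = app.transcribe("song.wav", output="./")  # also writes the MIDI file to output
print(len(midi.instruments))  # the returned object is a pretty_midi.PrettyMIDI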

Dataset

class omnizart.drum.app.PopDatasetLoader(mini_beat_per_seg=4, feature_folder=None, feature_files=None, num_samples=100, slice_hop=1)

Bases: omnizart.base.BaseDatasetLoader

Pop dataset loader for training the drum model.

Inference

omnizart.drum.inference.get_3inst_ary(inst_13_ary_in)
omnizart.drum.inference.inference(pred, m_beat_arr, bass_drum_th=0.85, snare_th=1.2, hihat_th=0.17)

Labels

omnizart.drum.labels.extract_label(label_path, m_beat_arr)

Extract drum label notes.

Process the ground-truth MIDI into a numpy array representation.

Parameters
label_path: Path

Path to the midi file.

m_beat_arr:

Extracted mini-beat array of the corresponding audio piece.

Returns
drum_track_ary: numpy.ndarray

The extracted labels as a numpy array, covering a total of 128 classes of drum notes.

See also

omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path

The function for extracting mini-beat array from the given audio path.

omnizart.drum.labels.extract_label_13_inst(label_path, m_beat_arr)

Extract 13 types of drum label notes.

Process the MIDI drum notes into a numpy array and consolidate them into 13 sub-classes of drum notes.

Parameters
label_path: Path

Path to the midi file.

m_beat_arr:

Extracted mini-beat array of the corresponding audio piece.

Returns
drum_track_ary: numpy.ndarray

The extracted labels as a numpy array.

See also

omnizart.drum.labels.extract_label

Complete drum label extraction with 128 output classes.

omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path

The function for extracting mini-beat array from the given audio path.
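
A sketch of extracting both label sets for one piece, assuming paired audio/MIDI files (piece.wav and piece.mid are hypothetical paths) and that extract_mini_beat_from_audio_path takes the audio path as its only required argument:

from omnizart.feature.beat_for_drum import extract_mini_beat_from_audio_path
from omnizart.drum.labels import extract_label, extract_label_13_inst

m_beat_arr = extract_mini_beat_from_audio_path("piece.wav")  # tracked mini-beat array
label_128 = extract_label("piece.mid", m_beat_arr)           # full 128-class labels
label_13 = extract_label_13_inst("piece.mid", m_beat_arr)    # merged 13-class labels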

Prediction

omnizart.drum.prediction.create_batches(feature, mini_beat_per_seg, b_size=6)

Create batches of 4D inputs for model prediction.

Parameters
feature: 3D numpy array

Should be in shape [mini_beat_pos x time x freq].

mini_beat_per_seg: int

Number of mini beats in one segment (a beat).

b_size: int

Output batch size.

Returns
batch_feature: 5D numpy array

Dimensions are [batches x b_size x time x freq x mini_beat_per_seg].

pad_size: int

The size of the padding added at the end of the last batch.

omnizart.drum.prediction.merge_batches(batch_pred)

The reverse process of create_batches.

Merges a 5D batched prediction back into a 2D output.

omnizart.drum.prediction.predict(patch_cqt_feature, model, mini_beat_per_seg, batch_size=32)
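
A sketch of the batching round-trip; the feature shape below is a hypothetical placeholder following the documented [mini_beat_pos x time x freq] layout, and the model is assumed to be a trained Keras model (not constructed here):

import numpy as np
from omnizart.drum.prediction import create_batches, merge_batches

feature = np.random.rand(64, 120, 120)  # hypothetical CQT patches
batches, pad_size = create_batches(feature, mini_beat_per_seg=4, b_size=6)

# With a trained model in hand, per-batch prediction and the reverse merge
# would look like:
# batch_pred = np.array([model.predict(batch) for batch in batches])
# pred = merge_batches(batch_pred)  # 5D batched prediction back to a 2D output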

Settings

Below are the default settings for building the drum model. They will be loaded by the class omnizart.setting_loaders.DrumSettings. The attribute names are converted to snake-case (e.g., HopSize -> hop_size). There is also a path-flattening step when the settings are applied to the DrumSettings instance: for example, to access the attribute BatchSize defined at the YAML path General/Training/Settings/BatchSize, the corresponding attribute is DrumSettings.training.batch_size. The /Settings level is removed from all fields, as the sketch below illustrates.
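
A minimal sketch of accessing the flattened attributes, assuming DrumSettings() loads the defaults listed below:

from omnizart.setting_loaders import DrumSettings

settings = DrumSettings()
print(settings.training.batch_size)  # General/Training/Settings/BatchSize -> 32
print(settings.feature.hop_size)     # General/Feature/Settings/HopSize -> 256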

General:
    TranscriptionMode:
        Description: Transcription mode used when executing the `omnizart drum transcribe` command.
        Type: String
        Value: Keras
    CheckpointPath:
        Description: Path to the pre-trained models.
        Type: Map
        SubType: [String, String]
        Value:
            Keras: checkpoints/drum/drum_keras
    Feature:
        Description: Default settings of feature extraction for drum transcription.
        Settings:
            SamplingRate:
                Description: Adjust input sampling rate to this value.
                Type: Integer
                Value: 44100
            PaddingSeconds:
                Description: Padding length (in seconds) added to the beginning and the end of the raw audio data.
                Type: Float
                Value: 1.0
            LowestNote:
                Description: Lowest MIDI note number to be considered.
                Type: Integer
                Value: 16
            NumberOfNotes:
                Description: Number of total notes to extract.
                Type: Integer
                Value: 120
            HopSize:
                Description: Hop size for computing CQT feature.
                Type: Integer
                Value: 256
            MiniBeatPerBar:
                Description: Number of mini beats in a single 4/4 measure.
                Type: Integer
                Value: 32
            MiniBeatPerSegment:
                Description: Number of mini beats in one segment (a beat).
                Type: Integer
                Value: 4
    Dataset:
        Description: Settings of datasets.
        Settings:
            SavePath:
                Description: Path for storing the downloaded datasets.
                Type: String
                Value: ./
            FeatureSavePath:
                Description: Path for storing the extracted feature. Defaults to a path under the dataset folder.
                Type: String
                Value: +
    Model:
        Description: Default settings of training / testing the model.
        Settings:
            SavePrefix:
                Description: Prefix of the trained model's name to be saved.
                Type: String
                Value: drum
            SavePath:
                Description: Path to save the trained model.
                Type: String
                Value: ./checkpoints/drum
    Inference:
        Description: Default settings for inferring notes.
        Settings:
            BassDrumTh:
                Description: Threshold for the bass drum.
                Type: Float
                Value: 0.85
            SnareTh:
                Description: Threshold for the snare.
                Type: Float
                Value: 1.2
            HihatTh:
                Description: Threshold for the hihat.
                Type: Float
                Value: 0.17
    Training:
        Description: Hyper-parameters for training.
        Settings:
            Epoch:
                Description: Maximum number of epochs for training.
                Type: Integer
                Value: 50
            Steps:
                Description: Number of training steps for each epoch.
                Type: Integer
                Value: 1000
            ValSteps:
                Description: Number of validation steps after each training epoch.
                Type: Integer
                Value: 100
            BatchSize:
                Description: Batch size of each training step.
                Type: Integer
                Value: 32
            ValBatchSize:
                Description: Batch size of each validation step.
                Type: Integer
                Value: 32
            EarlyStop:
                Description: Terminate the training if the validation performance doesn't improve after n epochs.
                Type: Integer
                Value: 6
            InitLearningRate:
                Description: Initial learning rate.
                Type: Float
                Value: 0.00002
            ResBlockNum:
                Description: Number of residual blocks.
                Type: Integer
                Value: 3
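
Both generate_feature and train accept a drum_settings instance, so individual values can be overridden before building the model. A minimal sketch, assuming DrumSettings() loads the defaults above and using a hypothetical feature folder:

from omnizart.setting_loaders import DrumSettings
from omnizart.drum.app import DrumTranscription

settings = DrumSettings()
settings.training.epoch = 10       # override General/Training/Settings/Epoch
settings.training.batch_size = 16  # override General/Training/Settings/BatchSize

app = DrumTranscription()
app.train("./Pop/train_feature", drum_settings=settings)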