---
pipeline_tag: audio-classification
tags:
- audio
- music
---

# MERT

MERT (Acoustic Music Understanding Model with Large-Scale Self-supervised Training) uses teacher models to provide pseudo labels for masked language modelling (MLM) style acoustic pre-training.

The pre-trained weights come from [m-a-p/MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M). In this repository, we register MERT with the [AutoModelForAudioClassification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForAudioClassification) auto class.

## Usage

```python
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Some configurations
model_id = 'yangwang825/mert-base'
batch_size = 4
num_classes = 10
max_duration = 1.0  # Clip length in seconds after padding/truncation

# Initialise the extractor and model
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id,
    trust_remote_code=True
)
mert = AutoModelForAudioClassification.from_pretrained(
    model_id,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
    trust_remote_code=True
)

# Simulate a list of waveforms (e.g. four audio clips)
audio_arrays = [
    np.random.rand(16000, ),
    np.random.rand(24000, ),
    np.random.rand(22050, ),
    np.random.rand(44100, )
]
inputs = feature_extractor(
    audio_arrays,  # List of waveforms in numpy array format
    sampling_rate=feature_extractor.sampling_rate,
    max_length=int(feature_extractor.sampling_rate * max_duration),
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)
# The shape of `input_values` is (batch_size, sampling_rate * max_duration)
input_values = inputs['input_values']
outputs = mert(**inputs)
# The shape of `logits` is (batch_size, num_classes)
logits = outputs['logits']
```
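
The `logits` above are unnormalised scores of shape `(batch_size, num_classes)`. As a minimal sketch of one way to turn them into class predictions, the snippet below applies a softmax and an argmax. It makes two assumptions that are not part of this repository: the logits are simulated with NumPy so the example runs without downloading the model, and the `id2label` mapping is hypothetical. Since the usage example passes `num_labels` together with `ignore_mismatched_sizes=True`, the classification head is presumably newly initialised, so predictions should only be meaningful after fine-tuning.

```python
import numpy as np

# Simulated logits standing in for `outputs['logits']` (assumption),
# with shape (batch_size, num_classes) = (4, 10)
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Class probabilities: shape (batch_size, num_classes), each row sums to 1
probs = softmax(logits)

# Index of the most likely class per clip: shape (batch_size,)
predicted_ids = probs.argmax(axis=-1)

# Hypothetical label names; a real mapping would come from your dataset
id2label = {i: f"class_{i}" for i in range(logits.shape[-1])}
predicted_labels = [id2label[int(i)] for i in predicted_ids]
```

With the PyTorch tensor returned by the usage example, `logits.detach().numpy()` yields an array that can be passed through the same steps.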