Epsilon617 committed on
Commit 8c952bb · 1 Parent(s): 826be26
add model for offline mode run

- MERT-v0-public/.gitattributes +34 -0
- MERT-v0-public/README.md +115 -0
- MERT-v0-public/__init__.py +0 -0
- MERT-v0-public/__pycache__/__init__.cpython-310.pyc +0 -0
- MERT-v0-public/__pycache__/configuration_MERT.cpython-310.pyc +0 -0
- MERT-v0-public/__pycache__/modeling_MERT.cpython-310.pyc +0 -0
- MERT-v0-public/config.json +84 -0
- MERT-v0-public/configuration_MERT.py +131 -0
- MERT-v0-public/modeling_MERT.py +409 -0
- MERT-v0-public/preprocessor_config.json +9 -0
- MERT-v0-public/pytorch_model.bin +3 -0
- __pycache__/app.cpython-310.pyc +0 -0
- app.py +12 -5
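This commit vendors a full copy of the m-a-p/MERT-v0-public model repository (custom code, config, and LFS-tracked weights) into the Space and touches app.py, so the demo can run without fetching the model from the Hub at startup. The actual app.py change is not shown in this excerpt; the snippet below is only a minimal sketch of what such an offline load typically looks like, with the local directory name taken from the files added here:

```python
from transformers import AutoModel, Wav2Vec2FeatureExtractor

# Hypothetical offline load: point at the vendored copy added by this commit
# instead of the "m-a-p/MERT-v0-public" Hub repo id.
LOCAL_MODEL_DIR = "./MERT-v0-public"

model = AutoModel.from_pretrained(LOCAL_MODEL_DIR, trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained(LOCAL_MODEL_DIR, trust_remote_code=True)
```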
 
    	
MERT-v0-public/.gitattributes ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
    	
MERT-v0-public/README.md ADDED
@@ -0,0 +1,115 @@
---
license: mit
inference: false
tags:
- music
---
# Introduction to our series work

The development log of our Music Audio Pre-training (m-a-p) model family:
- 17/03/2023: we release two advanced music understanding models, [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) and [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M), trained with a new paradigm and dataset. They outperform the previous models and generalize better to more tasks.
- 14/03/2023: we retrained the MERT-v0 model with an open-source-only music dataset: [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public)
- 29/12/2022: a music understanding model, [MERT-v0](https://huggingface.co/m-a-p/MERT-v0), trained with the **MLM** paradigm, which performs better on downstream tasks.
- 29/10/2022: a pre-trained MIR model, [music2vec](https://huggingface.co/m-a-p/music2vec-v1), trained with the **BYOL** paradigm.

Here is a table for quick model pick-up:

| Name | Pre-train Paradigm | Training Data (hour) | Pre-train Context (second) | Model Size | Transformer Layer-Dimension | Feature Rate | Sample Rate | Release Date |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) | MLM | 160K | 5 | 330M | 24-1024 | 75 Hz | 24K Hz | 17/03/2023 |
| [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) | MLM | 20K | 5 | 95M | 12-768 | 75 Hz | 24K Hz | 17/03/2023 |
| [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public) | MLM | 900 | 5 | 95M | 12-768 | 50 Hz | 16K Hz | 14/03/2023 |
| [MERT-v0](https://huggingface.co/m-a-p/MERT-v0) | MLM | 1000 | 5 | 95M | 12-768 | 50 Hz | 16K Hz | 29/12/2022 |
| [music2vec-v1](https://huggingface.co/m-a-p/music2vec-v1) | BYOL | 1000 | 30 | 95M | 12-768 | 50 Hz | 16K Hz | 30/10/2022 |

## Explanation

The m-a-p models share a similar model architecture; the most notable difference is the pre-training paradigm. Beyond that, there are a few technical configuration nuances to know before use:

- **Model Size**: the number of parameters loaded into memory. Please select a size appropriate for your hardware.
- **Transformer Layer-Dimension**: the number of transformer layers and the corresponding feature dimension output by the model. This is called out because features extracted from **different layers can perform differently depending on the task**.
- **Feature Rate**: the number of features the model outputs for one second of audio input.
- **Sample Rate**: the sampling rate of the audio the model was trained on.

# Introduction to MERT-v0-public

**MERT-v0-public** is a completely unsupervised model trained on the **completely non-commercial, open-source** [Music4All](https://sites.google.com/view/contact4music4all) dataset and the part of the [FMA_full](https://github.com/mdeff/fma) dataset that does not carry the "experimental" tag.

For the training settings and model usage of MERT-v0-public, refer to the [MERT-v0 model](https://huggingface.co/m-a-p/MERT-v0).

Details are reported in the short article *Large-Scale Pretrained Model for Self-Supervised Music Audio Representation Learning*.

# Demo code

```python
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import load_dataset


# loading our model weights
model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)

# load demo audio and set processor
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

resample_rate = processor.sampling_rate
# make sure the sample rate is aligned
if resample_rate != sampling_rate:
    print(f'setting rate from {sampling_rate} to {resample_rate}')
    resampler = T.Resample(sampling_rate, resample_rate)
else:
    resampler = None

# audio file is decoded on the fly
if resampler is None:
    input_audio = dataset[0]["audio"]["array"]
else:
    input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))

inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# take a look at the output shape; there are 13 layers of representation
# each layer performs differently on different downstream tasks, so choose empirically
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape)  # [13 layers, time steps, 768 feature_dim]

# for utterance-level classification tasks, you can simply reduce the representation over time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape)  # [13, 768]

# you can even use a learnable weighted-average representation
aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
print(weighted_avg_hidden_states.shape)  # [768]
```

# Citation

```bibtex
@article{li2022large,
  title={Large-Scale Pretrained Model for Self-Supervised Music Audio Representation Learning},
  author={Li, Yizhi and Yuan, Ruibin and Zhang, Ge and Ma, Yinghao and Lin, Chenghua and Chen, Xingran and Ragni, Anton and Yin, Hanzhi and Hu, Zhijie and He, Haoyu and others},
  year={2022}
}

@article{li2022map,
  title={MAP-Music2Vec: A Simple and Effective Baseline for Self-Supervised Music Audio Representation Learning},
  author={Li, Yizhi and Yuan, Ruibin and Zhang, Ge and Ma, Yinghao and Lin, Chenghua and Chen, Xingran and Ragni, Anton and Yin, Hanzhi and Hu, Zhijie and He, Haoyu and others},
  journal={arXiv preprint arXiv:2212.02508},
  year={2022}
}
```
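The README above stresses that features from different transformer layers suit different downstream tasks and should be chosen empirically. As a small illustration of that advice (not part of the committed files), here is a hedged sketch of a linear probe on a single layer's utterance-level feature; the tensor is a stand-in for the `time_reduced_hidden_states` produced by the demo code, and the layer index and class count are placeholders:

```python
import torch
from torch import nn

# Stand-in for the [13, 768] tensor produced by the README demo code
time_reduced_hidden_states = torch.randn(13, 768)

num_classes = 10   # placeholder: depends on the downstream task
layer_idx = 7      # placeholder: pick per task, e.g. by validation accuracy

# Linear probe on one layer's 768-dim utterance-level representation
probe = nn.Linear(768, num_classes)
logits = probe(time_reduced_hidden_states[layer_idx])
print(logits.shape)  # torch.Size([10])
```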
    	
MERT-v0-public/__init__.py ADDED
(empty file)

MERT-v0-public/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (161 Bytes).

MERT-v0-public/__pycache__/configuration_MERT.cpython-310.pyc ADDED
Binary file (3.37 kB).

MERT-v0-public/__pycache__/modeling_MERT.cpython-310.pyc ADDED
Binary file (10.3 kB).
    	
MERT-v0-public/config.json ADDED
@@ -0,0 +1,84 @@
{
  "_name_or_path": "m-a-p/MERT-v0-public",
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": [
    "MERTModel"
  ],
  "attention_relax": -1.0,
  "auto_map": {
    "AutoConfig": "configuration_MERT.MERTConfig",
    "AutoModel": "modeling_MERT.MERTModel"
  },
  "model_type": "mert_model",
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.1,
  "feat_proj_layer_norm": true,
  "feature_extractor_cqt": false,
  "feature_extractor_cqt_bins": 336,
  "final_dropout": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "num_attention_heads": 12,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "sample_rate": 16000,
  "tokenizer_class": "Wav2Vec2CTCTokenizer",
  "torch_dtype": "float32",
  "transformers_version": "4.25.1",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
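A quick sanity check of how this config relates to the numbers in the README table: the seven `conv_stride` values multiply to the downsampling factor of the convolutional feature encoder, and dividing `sample_rate` by that factor reproduces the 50 Hz feature rate listed for MERT-v0-public. A small sketch using only the values shown above:

```python
import json
from functools import reduce
from operator import mul

# Values copied from MERT-v0-public/config.json above
config = json.loads('{"sample_rate": 16000, "conv_stride": [5, 2, 2, 2, 2, 2, 2]}')

downsample = reduce(mul, config["conv_stride"], 1)  # 5*2*2*2*2*2*2 = 320 samples per frame
print(downsample)                                   # 320
print(config["sample_rate"] / downsample)           # 50.0 features per second
```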
    	
MERT-v0-public/configuration_MERT.py ADDED
@@ -0,0 +1,131 @@
"""
MERT model configuration
"""

import functools
import operator

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MERTConfig(PretrainedConfig):
    r"""
    """
    model_type = "mert_model"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_layer_norm=True,
        feat_proj_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        feature_extractor_cqt=False,
        feature_extractor_cqt_bins=336,
        deepnorm=False,
        attention_relax=-1.0,
        **kwargs
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_layer_norm = feat_proj_layer_norm
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.use_weighted_layer_sum = use_weighted_layer_sum
        self.classifier_proj_size = classifier_proj_size

        if (
            (len(self.conv_stride) != self.num_feat_extract_layers)
            or (len(self.conv_kernel) != self.num_feat_extract_layers)
            or (len(self.conv_dim) != self.num_feat_extract_layers)
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # cqt feature extractor
        self.feature_extractor_cqt = feature_extractor_cqt
        self.feature_extractor_cqt_bins = feature_extractor_cqt_bins

        # deepnorm: up-scale weighted residual connection + down-scale initial value transformer encoder
        self.deepnorm = deepnorm

        self.attention_relax = attention_relax

    @property
    def inputs_to_logits_ratio(self):
        return functools.reduce(operator.mul, self.conv_stride, 1)
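The `inputs_to_logits_ratio` property at the end of the file computes the same stride product as above: how many input samples map to one output frame. A minimal usage sketch of the committed class, assuming configuration_MERT.py is importable directly from the local directory (the `auto_map` entry in config.json arranges the same wiring when loading through `AutoConfig`/`AutoModel` with `trust_remote_code=True`):

```python
# Hypothetical direct use of the committed configuration class
# (assumes MERT-v0-public/configuration_MERT.py is on the Python path)
from configuration_MERT import MERTConfig

cfg = MERTConfig()                     # defaults mirror config.json
print(cfg.inputs_to_logits_ratio)      # 320: product of conv_stride (5*2*2*2*2*2*2)
print(16000 // cfg.inputs_to_logits_ratio)  # 50 frames per second at 16 kHz

# The constructor validates that conv_dim / conv_stride / conv_kernel have equal lengths
try:
    MERTConfig(conv_kernel=(10, 3, 3))  # deliberately too short: raises ValueError
except ValueError as err:
    print("rejected:", err)
```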
    	
MERT-v0-public/modeling_MERT.py ADDED
@@ -0,0 +1,409 @@
"""
MERT model definition.
We largely adapt codes from:
1. https://github.com/huggingface/transformers/blob/main/src/transformers/models/hubert/modeling_hubert.py
2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
"""

from typing import Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput
import torch
from torch import nn

from transformers.models.hubert.modeling_hubert import (
    HubertFeatureEncoder,
    HubertModel,
    HubertEncoderStableLayerNorm,
    HubertEncoder,
    HubertEncoderLayer,
    HubertPositionalConvEmbedding,
    HubertAttention,
    HubertFeedForward,
)

try:
    from nnAudio import features as nnAudioFeatures
    NNAUDIO_INSTALLED = True
except ImportError:
    print("WARNING: feature_extractor_cqt requires the library 'nnAudio'")
    NNAUDIO_INSTALLED = False

from .configuration_MERT import MERTConfig


class MERTFeatureProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.feat_proj_layer_norm = config.feat_proj_layer_norm
        self.feature_extractor_cqt = config.feature_extractor_cqt

        if self.feature_extractor_cqt:
            # v3 concat features
            self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
            print(f"feature dimension: {self.feature_dimension}")
        else:
            self.feature_dimension = config.conv_dim[-1]
        if self.feat_proj_layer_norm:
            self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
        self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        # non-projected hidden states are needed for quantization
        if self.feat_proj_layer_norm:
            hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class MERTModel(HubertModel):
    # overwrite config class
    config_class = MERTConfig
    base_model_prefix = "mert_model"

    def __init__(
        self,
        config: MERTConfig,
    ) -> None:
        """
        Initialize with the grandparent method HubertPreTrainedModel.__init__()
        and modify HubertModel.__init__()
        """
        super(HubertModel, self).__init__(config)

        self.config = config

        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = MERTFeatureProjection(config)  # replace the feature projection to introduce new features

        if self.config.feature_extractor_cqt:
            assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio', try again after `pip install nnAudio`"
            print('initializing cqt extractor for MERT')
            self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(sr=self.config.sample_rate, hop_length=self.config.sample_rate//50, fmin=32.7,
                    fmax=None, n_bins=self.config.feature_extractor_cqt_bins, bins_per_octave=self.config.feature_extractor_cqt_bins//7,
                    filter_scale=1, norm=1, window='hann', center=True,
         
     | 
| 15 | 
         
            +
                HubertModel,
         
     | 
| 16 | 
         
            +
                HubertEncoderStableLayerNorm,
         
     | 
| 17 | 
         
            +
                HubertEncoder,
         
     | 
| 18 | 
         
            +
                HubertEncoderLayer,
         
     | 
| 19 | 
         
            +
                HubertPositionalConvEmbedding,
         
     | 
| 20 | 
         
            +
                HubertAttention,
         
     | 
| 21 | 
         
            +
                HubertFeedForward,
         
     | 
| 22 | 
         
            +
            )
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            try:
         
     | 
| 25 | 
         
            +
                from nnAudio import features as nnAudioFeatures
         
     | 
| 26 | 
         
            +
                NNAUDIO_INSTALLED=True
         
     | 
| 27 | 
         
            +
            except:
         
     | 
| 28 | 
         
            +
                print("WARNING: feature_extractor_cqt requires the libray 'nnAudio'")
         
     | 
| 29 | 
         
            +
                NNAUDIO_INSTALLED=False
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            from .configuration_MERT import MERTConfig
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            class MERTFeatureProjection(nn.Module):
         
     | 
| 34 | 
         
            +
                def __init__(self, config):
         
     | 
| 35 | 
         
            +
                    super().__init__()
         
     | 
| 36 | 
         
            +
                    self.feat_proj_layer_norm = config.feat_proj_layer_norm
         
     | 
| 37 | 
         
            +
                    self.feature_extractor_cqt = config.feature_extractor_cqt
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                    if self.feature_extractor_cqt:
         
     | 
| 40 | 
         
            +
                        # v3 concat features
         
     | 
| 41 | 
         
            +
                        self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
         
     | 
| 42 | 
         
            +
                        print(f"feature dimention: {self.feature_dimension}")
         
     | 
| 43 | 
         
            +
                    else:
         
     | 
| 44 | 
         
            +
                        self.feature_dimension = config.conv_dim[-1]
         
     | 
| 45 | 
         
            +
                    if self.feat_proj_layer_norm:
         
     | 
| 46 | 
         
            +
                        self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
         
     | 
| 47 | 
         
            +
                    self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
         
     | 
| 48 | 
         
            +
                    self.dropout = nn.Dropout(config.feat_proj_dropout)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                def forward(self, hidden_states):
         
     | 
| 51 | 
         
            +
                    # non-projected hidden states are needed for quantization
         
     | 
| 52 | 
         
            +
                    if self.feat_proj_layer_norm:
         
     | 
| 53 | 
         
            +
                        hidden_states = self.layer_norm(hidden_states)
         
     | 
| 54 | 
         
            +
                    hidden_states = self.projection(hidden_states)
         
     | 
| 55 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 56 | 
         
            +
                    return hidden_states
         
     | 
class MERTModel(HubertModel):
    # overwrite config class
    config_class = MERTConfig
    base_model_prefix = "mert_model"

    def __init__(
        self,
        config: MERTConfig,
    ) -> None:
        """
        Initialize with the grandparent method HubertPreTrainedModel.__init__()
        and modify HubertModel.__init__().
        """
        super(HubertModel, self).__init__(config)

        self.config = config

        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = MERTFeatureProjection(config)  # replace the feature projection to introduce the new features

        if self.config.feature_extractor_cqt:
            assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio', try again after `pip install nnAudio`"
            print('initializing cqt extractor for MERT')
            self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(
                sr=self.config.sample_rate, hop_length=self.config.sample_rate // 50, fmin=32.7,
                fmax=None, n_bins=self.config.feature_extractor_cqt_bins,
                bins_per_octave=self.config.feature_extractor_cqt_bins // 7,
                filter_scale=1, norm=1, window='hann', center=True,
                pad_mode='constant', trainable=False,
                output_format='Magnitude', verbose=True)

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            assert not config.deepnorm, "must use post-layer_norm with deepnorm"
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            if config.deepnorm:
                self.encoder = HubertEncoder_extend(config)
            else:
                self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # return super().forward(input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        # add additional cqt features for transformer input
        if self.config.feature_extractor_cqt:
            features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
            features_cqt = features_cqt[:, :extract_features.shape[1], :]  # align shape
            # # v2
            # features_cqt = self.post_cqt_feature_proj(features_cqt)
            # extract_features = self.feature_projection.layer_norm(extract_features) + self.feature_projection.layer_norm(features_cqt) # v2
            # v3
            extract_features = torch.cat([extract_features, features_cqt], 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]  # take last_hidden from the encoder output

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
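As a sanity check that the model class above can be used fully offline, here is a minimal usage sketch. The local directory name and the importlib indirection mirror the app.py change later in this commit; the one second of random audio, and the assumption that any optional dependencies required by the shipped config (such as nnAudio when CQT features are enabled) are installed, are mine rather than part of the commit:

    import importlib

    import torch
    from transformers import Wav2Vec2FeatureExtractor

    # The folder name contains a hyphen, so it cannot be named in a plain `import` statement.
    modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

    model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
    processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")

    # Illustrative input: one second of noise at the processor's 16 kHz sampling rate.
    waveform = torch.randn(processor.sampling_rate)
    inputs = processor(waveform.numpy(), sampling_rate=processor.sampling_rate, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # One tensor per layer (plus the projected input), each of shape (batch, frames, hidden_size).
    print(len(outputs.hidden_states), outputs.last_hidden_state.shape)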
         
class HubertEncoder_extend(HubertEncoder):
    def __init__(self, config):
        # super().__init__()
        # call nn module initialization
        nn.Module.__init__(self)
        # super(HubertEncoder_extend, self).__init__()

        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

        if config.deepnorm:
            import math
            init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
            for name, p in self.named_parameters():
                if (
                    "feed_forward.intermediate_dense" in name
                    or "feed_forward.output_dense" in name
                    or "out_proj" in name
                    or "v_proj" in name
                ):
                    p.data.div_(init_scale)
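The two constants used for DeepNorm here and in the encoder layer below match the DeepNet recipe for an encoder-only stack: residual branches are up-weighted by (2N)^0.25 and the listed projection weights are shrunk by (8N)^0.25. A quick numeric check, assuming N = 12 hidden layers (an assumption for illustration):

    import math

    num_hidden_layers = 12  # assumed for illustration

    residual_alpha = math.pow(2.0 * num_hidden_layers, 0.25)  # used as: residual * alpha + f(residual)
    init_scale = math.pow(8.0 * num_hidden_layers, 0.25)      # ffn / out_proj / v_proj weights are divided by this

    print(f"alpha = {residual_alpha:.3f}, 1/init_scale = {1 / init_scale:.3f}")
    # alpha = 2.213, 1/init_scale = 0.319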
         
class HubertEncoderLayerExtend(HubertEncoderLayer):
    def __init__(self, config):
        nn.Module.__init__(self)
        # super(HubertEncoderLayerExtend, self).__init__()
        if config.attention_relax > 0:
            self.attention = HubertAttention_extend(
                embed_dim=config.hidden_size,
                num_heads=config.num_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=False,
                attention_relax=config.attention_relax,
            )
        else:
            self.attention = HubertAttention(
                embed_dim=config.hidden_size,
                num_heads=config.num_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=False,
            )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if config.deepnorm:
            import math
            self.residual_alpha = math.pow(2.0 * config.num_hidden_layers, 0.25)
        else:
            self.residual_alpha = 1.0

    def residual_connection(self, x, residual):
        '''
        residual: input before f()
        x: output of f(residual)
        '''
        return residual * self.residual_alpha + x

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)

        # hidden_states = attn_residual + hidden_states
        hidden_states = self.residual_connection(hidden_states, attn_residual)

        hidden_states = self.layer_norm(hidden_states)

        # hidden_states = hidden_states + self.feed_forward(hidden_states)
        ffn_residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.residual_connection(hidden_states, ffn_residual)

        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
         
class HubertAttention_extend(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        attention_relax: float = -1.0,
    ):
        super().__init__()
        # nn.Module.__init__(self)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # store the relaxation factor unconditionally so that `forward` can always test it;
        # a non-positive value disables the relaxation branch
        self.attention_relax = attention_relax

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.attention_relax > 0:
            # => (bsz, self.num_heads, tgt_len, src_len)
            # attn_weights_relax = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)/self.attention_relax
            # => (bsz*self.num_heads, tgt_len, src_len)
            attn_weights_relax = attn_weights / self.attention_relax

            # row-wise max => (bsz*self.num_heads, tgt_len, 1)
            # torch.max(..., dim=...) returns a (values, indices) namedtuple, so take `.values`
            attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=True).values
            attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
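The attention_relax branch in the forward pass above is algebraically a row-wise max subtraction: (x / r - max(x / r)) * r equals x - max(x) for any r > 0, and softmax is invariant to subtracting a per-row constant. The detour through the divided domain only keeps the intermediate max in a safer numeric range (useful under fp16). A quick check on dummy logits:

    import torch

    torch.manual_seed(0)
    attention_relax = 32.0
    logits = 50.0 * torch.randn(24, 10, 10)  # (bsz * num_heads, tgt_len, src_len), deliberately large

    relaxed = logits / attention_relax
    stabilized = (relaxed - relaxed.max(dim=-1, keepdim=True).values) * attention_relax

    # Identical attention probabilities, up to floating-point error.
    print(torch.allclose(stabilized.softmax(dim=-1), logits.softmax(dim=-1)))  # True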
         
        MERT-v0-public/preprocessor_config.json
    ADDED
    
@@ -0,0 +1,9 @@
{
  "do_normalize": false,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}
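This configuration makes the checkpoint loadable with the stock Wav2Vec2FeatureExtractor: mono input (feature_size 1), no normalization, right padding, attention masks returned, and a 16 kHz sampling rate. A short preparation sketch; the file path is illustrative, and the resampling step is an assumption that mirrors the torchaudio imports in app.py rather than code shown in this commit:

    import torchaudio
    import torchaudio.transforms as T
    from transformers import Wav2Vec2FeatureExtractor

    processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")

    waveform, sr = torchaudio.load("some_music.wav")  # illustrative path
    waveform = waveform.mean(dim=0)                   # downmix to mono, since feature_size is 1
    if sr != processor.sampling_rate:                 # 16000 according to the config above
        waveform = T.Resample(sr, processor.sampling_rate)(waveform)

    inputs = processor(waveform.numpy(), sampling_rate=processor.sampling_rate, return_tensors="pt")
    print(inputs["input_values"].shape)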
         
        MERT-v0-public/pytorch_model.bin
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b25bde740483579d9895f35d074a949f6593ef48449b6d76e26ee3c0e5e9acb
size 377552987
        __pycache__/app.cpython-310.pyc
    CHANGED
    
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
        app.py
    CHANGED
    
@@ -1,4 +1,5 @@
 import gradio as gr
+#
 from transformers import Wav2Vec2FeatureExtractor
 from transformers import AutoModel
 import torch
@@ -6,6 +7,10 @@ from torch import nn
 import torchaudio
 import torchaudio.transforms as T
 import logging
+
+import importlib
+modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
+
 # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
 
 
@@ -21,7 +26,7 @@ logger.addHandler(ch)
 
 
 inputs = [gr.components.Audio(type="filepath", label="Add music audio file"),
-          gr.
+          gr.components.Audio(source="microphone", optional=True, type="filepath"),
           ]
 outputs = [gr.components.Textbox()]
 # outputs = [gr.components.Textbox(), transcription_df]
@@ -33,10 +38,12 @@ audio_examples = [
    # ["input/example-2.wav"],
 ]
 
-# Load the model
-model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
-#
-
+# Load the model and the corresponding preprocessor config
+# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
+# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
+model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
+processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
+
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model.to(device)
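The rest of the Space's inference code is not part of this diff, so the following is only an illustrative sketch of a common way to turn the per-layer hidden states returned by MERTModel into a single clip-level embedding (the layer count, frame count and hidden size are assumed, not read from this checkpoint):

    import torch

    # Stand-in for outputs.hidden_states from a forward pass with output_hidden_states=True:
    # 13 tensors (12 layers plus the projected input), batch of 1, 49 frames, width 768 (all assumed).
    hidden_states = tuple(torch.randn(1, 49, 768) for _ in range(13))

    all_layers = torch.stack(hidden_states)  # (13, 1, 49, 768)
    clip_embedding = all_layers.mean(dim=2)  # average over time: (13, 1, 768)
    print(clip_embedding.shape)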