Spaces:
Paused
Paused
| # Gaepago model V1 (CPU Test) | |
| # import package | |
| from transformers import AutoModelForAudioClassification | |
| from transformers import AutoFeatureExtractor | |
| from transformers import pipeline | |
| from datasets import Dataset, Audio | |
| import gradio as gr | |
| import torch | |
| from utils.postprocess import text_mapping,text_encoding | |
| import json | |
| import os | |
# Set model & dataset names
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"
TEXT_LABEL = "text_label.json"

# Import model & feature extractor.
# Inference uses a locally stored int8-quantized TorchScript export instead of
# the Hub checkpoint; the Hub config is still loaded because the inference
# function reads its id2label mapping.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(MODEL_NAME)
# Run the model on CPU. map_location="cpu" remaps every tensor in the archive
# at load time, so loading works even if the export was saved from a CUDA
# device (a later .to("cpu") would be too late in that case). eval() switches
# off training-only behaviour such as dropout on the scripted module.
model = torch.jit.load("./model/gaepago-20-lite/model_quant_int8.pt", map_location="cpu")
model.eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# Load the text-label mapping consumed by text_mapping() during post-processing.
with open(TEXT_LABEL, "r", encoding="utf-8") as f:
    text_label = json.load(f)
# Gaepago inference function
def gaepago_fn(tmp_audio_dir):
    """Classify a recorded dog sound and return its mapped display text.

    Args:
        tmp_audio_dir: Filesystem path handed over by the ``gr.Audio``
            component (``type="filepath"``); falsy (``None``) when no audio
            was recorded or uploaded.

    Returns:
        The output of ``text_mapping`` for the predicted class label.

    Raises:
        gr.Error: If no audio file was supplied.
    """
    # Log the incoming path to the Space console for debugging.
    print(tmp_audio_dir)
    if not tmp_audio_dir:
        # Without this guard the audio column decodes to None and the
        # ["array"] lookup below fails with an opaque TypeError.
        raise gr.Error("No audio received. Please record a sound before translating.")

    # Let the datasets Audio feature decode the file and resample it to 16 kHz.
    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=16000)
    )
    # Each audio_dataset[0] access re-decodes the file, so read it only once.
    sample = audio_dataset[0]["audio"]
    inputs = feature_extractor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs)["logits"]
    # Reduce over the class axis only: for a single clip this equals a flat
    # argmax, and for several clips .item() fails loudly instead of returning
    # a meaningless flattened index.
    predicted_class_ids = torch.argmax(logits, dim=-1).item()
    predicted_label = config.id2label[predicted_class_ids]

    # Post-processing: map the raw label to its display text.
    return text_mapping(predicted_label, text_label)
# Main
example_list = [
    "./sample/bark_sample.wav",
    "./sample/growling_sample.wav",
    "./sample/howl_sample.wav",
    "./sample/panting_sample.wav",
    "./sample/whimper_sample.wav",
]
main_api = gr.Blocks()
with main_api as demo:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        # Label: "Press the record button and let us hear what Iko is saying."
        audio = gr.Audio(
            source="microphone",
            type="filepath",
            label="녹음버튼을 눌러 이코가 하는 말을 들려주세요",
        )
        # Label: "What Iko is saying right now is..."
        transcription = gr.Textbox(label="지금 이코가 하는 말은...")
    # Button: "Translate dog language!"
    b1 = gr.Button("강아지 언어 번역!")
    b1.click(gaepago_fn, inputs=audio, outputs=transcription, api_name="predict")
    examples = gr.Examples(examples=example_list, inputs=[audio])
demo.launch(show_error=True)