import os
import re
import threading

import streamlit as st
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import pandas as pd
from googleapiclient.discovery import build
def get_video_id(url):
    # Extract the video ID from either the long or the short YouTube URL form.
    link_pref = "https://www.youtube.com/watch?v="
    link_pref_2 = "https://youtu.be/"
    if url.startswith(link_pref):
        end = url.find("&")
        return url[len(link_pref): end if end != -1 else None]
    elif url.startswith(link_pref_2):
        end = url.find("?")
        return url[len(link_pref_2): end if end != -1 else None]
    else:
        raise Exception("YOU NEED TO PASTE YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)
def remove_repeated_substrings(s, n, k):
    # Collapse any substring of length >= n that repeats k or more times
    # back-to-back down to a single occurrence.
    if k < 2 or n < 1 or len(s) < n * k:
        return s
    original = s
    max_m = len(original) // k
    for m in range(max_m, n - 1, -1):
        pattern = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            new_s, replacements = pattern.subn(r"\1", original)
            if replacements == 0:
                break
            original = new_s
    return original
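# Illustrative example: with n=2 and k=5, five back-to-back repeats of a
# two-character block collapse to a single occurrence:
#   remove_repeated_substrings("lol hahahahaha", 2, 5)  # -> "lol ha"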
def generate_answer(_model, _tokenizer, prompt_text,
                    _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7, do_sample=True):
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator
    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]
    model_max_length = _tokenizer.model_max_length if hasattr(_tokenizer, 'model_max_length') else 1024
    if input_length >= model_max_length - max_new_tokens:
        # Keep only the most recent tokens so that prompt + generated tokens
        # fit inside the model's context window (with a small safety margin).
        allowed_input_length = model_max_length - max_new_tokens - 5
        inputs['input_ids'] = inputs['input_ids'][:, -allowed_input_length:]
        inputs['attention_mask'] = inputs['attention_mask'][:, -allowed_input_length:]
    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # GPT-2 has no pad token, so fall back to the EOS token for padding.
        pad_token_id=_tokenizer.pad_token_id if _tokenizer.pad_token_id is not None else _tokenizer.eos_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation on a background thread so tokens can be streamed to the UI.
    thread = threading.Thread(target=_model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()
    # Final pass: collapse degenerate repetitions and strip literal "\n" artifacts.
    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text
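# Sketch of how the generator is consumed (assumes the `model` and `tokenizer`
# returned by load_model() below): each iteration yields the text generated so
# far, and the final item is the post-processed comment.
#   for partial in generate_answer(model, tokenizer, "Title: My video"):
#       print(partial)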
RELEVANT_FIELDS = ['title', 'channel_title', 'category', 'tags', 'views', 'likes', 'dislikes']
PROMPT_TEMPLATE = (
    "Video Information:\n"
    "Title: {title}\n"
    "Channel: {channel_title}\n"
    "Category: {category}\n"
    "Tags: {tags}\n"
    "Views: {views}\n"
    "Likes: {likes}\n"
    "Dislikes: {dislikes}\n"
    "\n"
    "Comment:\n"
    "{comment_text}"
)
def format_data_for_lm(example):
    try:
        metadata = {field: str(example.get(field, 'N/A')) for field in RELEVANT_FIELDS}
        metadata['comment_text'] = str(example.get('comment_text', ''))
        formatted_text = PROMPT_TEMPLATE.format(**metadata)
        return {"text": formatted_text}
    except Exception as e:
        st.error(f"Error formatting example: {e}")
        return {"text": ""}
def build_service():
    # Read the API key from the environment (e.g. a Space secret) rather than
    # hardcoding it in the source; "YOUTUBE_API_KEY" is an assumed variable name.
    key = os.environ.get("YOUTUBE_API_KEY")
    return build("youtube", "v3", developerKey=key)
def get_video_info(video_id):
    api = build_service()
    response = api.videos().list(part="snippet,contentDetails,statistics", id=video_id).execute()
    snippet = response['items'][0]['snippet']
    statistics = response['items'][0]['statistics']
    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2] != "en":
        raise Exception(f"Language {lang} is not supported")
    video_info = {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': snippet['categoryId'],
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics['viewCount'],
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }
    return video_info
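# Illustrative shape of the returned metadata (values are hypothetical; the
# YouTube Data API returns counts as strings):
#   {'title': 'My video', 'channel_title': 'My channel', 'category': '24',
#    'tags': 'music|live', 'views': '12345', 'likes': '678', 'dislikes': 0}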
@st.cache_resource
def load_model():
    # Cache the model across Streamlit reruns so it is downloaded and
    # loaded only once.
    model_name = "Alekhon/gpt2-clown-commenter"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model
def predict_comment(link):
    try:
        video_id = get_video_id(link)
        video_info = get_video_info(video_id)
        prompt = format_data_for_lm(video_info)['text']
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"
st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")
st.image(Image.open("hw4/clown.jpg"), width=400)
tokenizer, model = load_model()
link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")
if st.button('Generate! 🤡'):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""
            # Stream partial generations into the placeholder as they arrive.
            for partial_text in result:
                final_text = partial_text
                comment_placeholder.markdown(f"**Comment:**\n{final_text}")
            generating_placeholder.empty()
            # The generator's last yield is already post-processed; run the
            # cleanup once more defensively before showing the final comment.
            processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
            comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)