"""Streamlit app that generates a YouTube comment from a video's metadata.

The script fetches title/channel/category/tags/statistics for a pasted
YouTube link through the YouTube Data API, formats them into a prompt and
streams a completion from a causal language model into the page.
"""

import os
import re
import threading

import pandas as pd
import streamlit as st
import torch
from accelerate import init_empty_weights
from googleapiclient.discovery import build
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Metadata fields substituted into PROMPT_TEMPLATE (besides comment_text).
RELEVANT_FIELDS = ['title', 'channel_title', 'category', 'tags', 'views', 'likes', 'dislikes']
PROMPT_TEMPLATE = "Video Information:\nTitle: {title}\nChannel: {channel_title}\nCategory: {category}\nTags: {tags}\nViews: {views}\nLikes: {likes}\nDislikes: {dislikes}\n\nComment:\n{comment_text}"

# Maps the numeric categoryId returned by the API to a display name.
# Defined before the functions that read it so they do not depend on
# statement order further down the script.
category_dict = {
    1: 'Film & Animation', 2: 'Autos & Vehicles', 10: 'Music',
    15: 'Pets & Animals', 17: 'Sports', 18: 'Short Movies',
    19: 'Travel & Events', 20: 'Gaming', 21: 'Videoblogging',
    22: 'People & Blogs', 23: 'Comedy', 24: 'Entertainment',
    25: 'News & Politics', 26: 'Howto & Style', 27: 'Education',
    28: 'Science & Technology', 29: 'Nonprofits & Activism', 30: 'Movies',
    31: 'Anime/Animation', 32: 'Action/Adventure', 33: 'Classics',
    34: 'Comedy', 35: 'Documentary', 36: 'Drama', 37: 'Family',
    38: 'Foreign', 39: 'Horror', 40: 'Sci-Fi/Fantasy', 41: 'Thriller',
    42: 'Shorts', 43: 'Shows', 44: 'Trailers',
}

# (URL prefix, character that terminates the video id) for each accepted form.
_YOUTUBE_URL_FORMS = (
    ("https://www.youtube.com/watch?v=", "&"),
    ("https://youtu.be/", "?"),
)

# Environment variable that overrides the API key below.
_API_KEY_ENV_VAR = "YOUTUBE_API_KEY"

# SECURITY(review): this key is committed to source control and must be
# treated as leaked. It is kept only as a fallback so existing deployments
# keep working; rotate it and supply the new key via YOUTUBE_API_KEY.
_LEGACY_API_KEY = "AIzaSyB3hMSp3LgMbpr-gD-btHWeKAvf7PhrPiw"


def get_video_id(url):
    """Extract the video id from a YouTube watch or youtu.be URL.

    Args:
        url: The link pasted by the user. Surrounding whitespace is ignored.

    Returns:
        The part of the URL after the recognised prefix, up to (not
        including) the first ``&`` for watch links or ``?`` for short links.

    Raises:
        ValueError: If the URL starts with neither accepted prefix.
    """
    url = url.strip()
    for prefix, terminator in _YOUTUBE_URL_FORMS:
        if url.startswith(prefix):
            video_id, _, _ = url[len(prefix):].partition(terminator)
            return video_id
    # Report the argument itself; the original concatenated the global `link`.
    raise ValueError("YOU NEED TO PASTE YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)


def remove_repeated_substrings(s, n, k):
    """Collapse runs of a substring repeated back-to-back into one copy.

    For every substring length ``m`` from ``len(s) // k`` down to ``n``, any
    run of at least ``k`` consecutive identical ``m``-character substrings is
    replaced by a single occurrence, repeating until nothing changes.

    Args:
        s: Text to clean.
        n: Minimum substring length to consider.
        k: Minimum number of consecutive repeats that triggers collapsing.

    Returns:
        The cleaned text, or ``s`` unchanged when ``k < 2``, ``n < 1`` or the
        text is shorter than ``n * k``.
    """
    if k < 2 or n < 1 or len(s) < n * k:
        return s

    text = s
    longest = len(text) // k
    for m in range(longest, n - 1, -1):
        # `.` does not match newlines, so a repeated unit never spans lines.
        pattern = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            collapsed, replacements = pattern.subn(r"\1", text)
            if replacements == 0:
                break
            text = collapsed
    return text


def generate_answer(_model, _tokenizer, prompt_text,
                    _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7,
                    do_sample=True):
    """Stream a model completion for ``prompt_text``.

    This is a generator: nothing runs until it is iterated, so tokenization
    and generation errors are raised during iteration, not at call time.

    Args:
        _model: Model exposing ``generate``.
        _tokenizer: Tokenizer matching the model.
        prompt_text: Prompt body; it is wrapped in the prompt/answer markers.
        _device: Device the input tensors are moved to.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature passed to ``generate``.
        top_k: Top-k sampling parameter passed to ``generate``.
        top_p: Nucleus sampling parameter passed to ``generate``.
        do_sample: Whether ``generate`` samples.

    Yields:
        The accumulated generated text after each streamed chunk, followed by
        one final value with repeated substrings collapsed and literal
        ``\\n`` sequences removed.

    Raises:
        RuntimeError: If ``_model.generate`` fails in the worker thread.
    """
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator

    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    inputs = {name: tensor.to(_device) for name, tensor in inputs.items()}

    input_length = inputs["input_ids"].shape[1]
    model_max_length = getattr(_tokenizer, 'model_max_length', 1024)
    if input_length >= model_max_length - max_new_tokens:
        # Keep the tail of the prompt so the answer marker survives. Clamp to
        # at least one token: a zero or negative length would turn the
        # negative slice below into a different (wrong) selection.
        allowed_input_length = max(1, model_max_length - max_new_tokens - 5)
        inputs['input_ids'] = inputs['input_ids'][:, -allowed_input_length:]
        inputs['attention_mask'] = inputs['attention_mask'][:, -allowed_input_length:]

    # Fall back to EOS when the tokenizer defines no padding token.
    pad_token_id = _tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = _tokenizer.eos_token_id

    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    failures = []

    def _run_generation():
        # If generate() raises, the streamer never receives its end signal
        # and the consumer loop below would block forever; end it explicitly.
        try:
            _model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()

    if failures:
        raise RuntimeError(f"Text generation failed: {failures[0]}") from failures[0]

    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text


def format_data_for_lm(example):
    """Render a metadata mapping into the prompt text.

    Args:
        example: Mapping with the keys in ``RELEVANT_FIELDS`` and optionally
            ``comment_text``. Missing fields become ``'N/A'``; a missing
            comment becomes the empty string.

    Returns:
        ``{"text": prompt}``; on a formatting failure the error is shown with
        ``st.error`` and ``{"text": ""}`` is returned.
    """
    try:
        metadata = {field: str(example.get(field, 'N/A')) for field in RELEVANT_FIELDS}
        metadata['comment_text'] = str(example.get('comment_text', ''))
        formatted_text = PROMPT_TEMPLATE.format(**metadata)
        return {"text": formatted_text}
    except Exception as e:
        # Deliberate best-effort: report in the UI and continue with an
        # empty prompt instead of aborting the run.
        st.error(f"Error formatting example: {e}")
        return {"text": ""}


@st.cache_resource
def build_service():
    """Create the YouTube Data API v3 client (cached across reruns).

    The key is read from the ``YOUTUBE_API_KEY`` environment variable and
    falls back to the legacy in-source key when that is unset or empty.
    """
    key = os.environ.get(_API_KEY_ENV_VAR) or _LEGACY_API_KEY
    return build("youtube", "v3", developerKey=key)


@st.cache_data(ttl=600)
def get_video_info(video_id):
    """Fetch the metadata used to build the prompt for one video.

    Results are cached for ten minutes so repeated clicks do not re-query the
    API while view/like counts are still refreshed periodically.

    Args:
        video_id: YouTube video id.

    Returns:
        Dict with the keys ``title``, ``channel_title``, ``category``,
        ``tags`` (joined with ``|``), ``views``, ``likes`` and ``dislikes``.

    Raises:
        ValueError: If the API returns no item for the id, or the video's
            ``defaultAudioLanguage`` does not start with ``en`` (videos
            without that field are treated as English).
    """
    api = build_service()
    response = api.videos().list(
        part="snippet,contentDetails,statistics", id=video_id
    ).execute()

    items = response.get('items') or []
    if not items:
        # Previously surfaced as an opaque "list index out of range".
        raise ValueError(f"Video {video_id} not found or not accessible")

    snippet = items[0]['snippet']
    statistics = items[0].get('statistics', {})

    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2] != "en":
        raise ValueError(f"Language {lang}, not supported")

    return {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': category_dict.get(int(snippet['categoryId']), "Unknown category"),
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics.get('viewCount', 0),
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }


@st.cache_resource
def load_model():
    """Load the tokenizer and model once and move the model to GPU if present.

    Returns:
        ``(tokenizer, model)``.
    """
    model_name = "Alekhon/gpt2-clown-commenter"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model


def predict_comment(link):
    """Prepare comment generation for a YouTube link.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        link: URL pasted by the user.

    Returns:
        ``(True, generator)`` where the generator yields progressively longer
        comment text, or ``(False, message)`` if resolving the link or
        fetching metadata failed. The generator is lazy, so generation errors
        are raised while iterating it, not here.
    """
    try:
        video_id = get_video_id(link)
        video_info = get_video_info(video_id)
        prompt = format_data_for_lm(video_info)['text']
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"


st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")
st.image(Image.open("clown.jpg"), width=400)

tokenizer, model = load_model()

link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")

if st.button('Generate! 🤡'):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""
            try:
                # Generation runs lazily while iterating, so failures
                # surface here rather than inside predict_comment.
                for partial_text in result:
                    final_text = partial_text
                    comment_placeholder.markdown(f"**Comment:**\n{final_text}")
            except Exception as e:
                generating_placeholder.empty()
                comment_placeholder.empty()
                st.error(f"Error generating comment: {e}")
            else:
                generating_placeholder.empty()
                processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
                comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)