import re
import threading

import streamlit as st
import torch
from googleapiclient.discovery import build
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


def get_video_id(url):
    """Extract the video id from a youtube.com or youtu.be URL."""
    link_pref = "https://www.youtube.com/watch?v="
    link_pref_2 = "https://youtu.be/"
    if url.startswith(link_pref):
        end = url.find("&")
        return url[len(link_pref): end if end != -1 else None]
    elif url.startswith(link_pref_2):
        end = url.find("?")
        return url[len(link_pref_2): end if end != -1 else None]
    else:
        raise ValueError("YOU NEED TO PASTE A YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)


def remove_repeated_substrings(s, n, k):
    """Collapse any substring of length >= n that repeats k or more times in a row
    down to a single occurrence."""
    if k < 2 or n < 1 or len(s) < n * k:
        return s
    original = s
    max_m = len(original) // k
    # Try the longest candidate lengths first so larger repeats collapse before
    # their shorter sub-patterns.
    for m in range(max_m, n - 1, -1):
        pattern = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            new_s, replacements = pattern.subn(r"\1", original)
            if replacements == 0:
                break
            original = new_s
    return original


def generate_answer(_model, _tokenizer, prompt_text,
                    _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7,
                    do_sample=True):
    """Stream tokens from the model, yielding the accumulated text after each chunk."""
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator

    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]

    # Truncate the prompt from the left so there is room for max_new_tokens of output.
    model_max_length = getattr(_tokenizer, "model_max_length", 1024)
    if input_length >= model_max_length - max_new_tokens:
        allowed_input_length = model_max_length - max_new_tokens - 5
        inputs["input_ids"] = inputs["input_ids"][:, -allowed_input_length:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -allowed_input_length:]

    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # GPT-2 has no pad token, so fall back to EOS when padding is needed.
        pad_token_id=_tokenizer.pad_token_id
        if _tokenizer.pad_token_id is not None else _tokenizer.eos_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    # generate() blocks, so run it in a background thread and consume the streamer here.
    thread = threading.Thread(target=_model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()

    # Final pass: collapse repeated substrings and drop literal "\n" escape sequences.
    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text


RELEVANT_FIELDS = ["title", "channel_title", "category", "tags", "views", "likes", "dislikes"]
PROMPT_TEMPLATE = (
    "Video Information:\n"
    "Title: {title}\n"
    "Channel: {channel_title}\n"
    "Category: {category}\n"
    "Tags: {tags}\n"
    "Views: {views}\n"
    "Likes: {likes}\n"
    "Dislikes: {dislikes}\n"
    "\n"
    "Comment:\n"
    "{comment_text}"
)


def format_data_for_lm(example):
    """Render a video-metadata dict into the prompt format the model was trained on."""
    try:
        metadata = {field: str(example.get(field, "N/A")) for field in RELEVANT_FIELDS}
        metadata["comment_text"] = str(example.get("comment_text", ""))
        return {"text": PROMPT_TEMPLATE.format(**metadata)}
    except Exception as e:
        st.error(f"Error formatting example: {e}")
        return {"text": ""}


@st.cache_resource
def build_service():
    # Do not hard-code API keys in source; this assumes a YOUTUBE_API_KEY entry
    # in .streamlit/secrets.toml.
    return build("youtube", "v3", developerKey=st.secrets["YOUTUBE_API_KEY"])


@st.cache_data
def get_video_info(video_id):
    """Fetch title, channel, category, tags, and statistics for a video
    via the YouTube Data API."""
    api = build_service()
    response = api.videos().list(part="snippet,contentDetails,statistics", id=video_id).execute()
    if not response.get("items"):
        raise ValueError(f"No video found for id {video_id!r}")
    snippet = response["items"][0]["snippet"]
    statistics = response["items"][0]["statistics"]
    lang = snippet.get("defaultAudioLanguage", "en")
    if lang[:2] != "en":
        raise ValueError(f"Language {lang} not supported")
    return {
        "title": snippet["title"],
        "channel_title": snippet["channelTitle"],
        "category": snippet["categoryId"],
        "tags": "|".join(snippet.get("tags", [])),
        "views": statistics["viewCount"],
        "likes": statistics.get("likeCount", 0),
        "dislikes": statistics.get("dislikeCount", 0),
    }


@st.cache_resource
def load_model():
    """Load the fine-tuned GPT-2 commenter once and keep it for the session."""
    model_name = "Alekhon/gpt2-clown-commenter"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model


def predict_comment(link):
    try:
        video_id = get_video_id(link)
        video_info = get_video_info(video_id)
        prompt = format_data_for_lm(video_info)["text"]
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"


st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")
st.image(Image.open("hw4/clown.jpg"), width=400)

tokenizer, model = load_model()

link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")

if st.button("Generate! 🤡"):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""
            # Re-render the comment as each streamed chunk arrives.
            for partial_text in result:
                final_text = partial_text
                comment_placeholder.markdown(f"**Comment:**\n{final_text}")
            generating_placeholder.empty()
            # The generator's final yield is already post-processed, so this
            # re-application is effectively an idempotent safeguard.
            processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
            comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)
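
# A minimal sketch of how to run this app locally (app.py is a placeholder file
# name, and the YOUTUBE_API_KEY secret is an assumption, not part of the original):
#
#   1. Save this file as app.py.
#   2. Put YOUTUBE_API_KEY = "<your key>" in .streamlit/secrets.toml.
#   3. Run: streamlit run app.py
#
# Then paste a full watch URL (e.g. https://www.youtube.com/watch?v=<id>) into the
# text box and press "Generate! 🤡"; the comment streams into the page token by token.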