import streamlit as st
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from googleapiclient.discovery import build
import re
import threading
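
# Extract the 11-character video ID from a standard watch URL or a
# youtu.be short link, dropping any trailing query parameters, e.g.:
#   get_video_id("https://youtu.be/dQw4w9WgXcQ?t=5") -> "dQw4w9WgXcQ"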
@st.cache_data
def get_video_id(url):
    link_pref = "https://www.youtube.com/watch?v="
    link_pref_2 = "https://youtu.be/"
    if url.startswith(link_pref):
        end = url.find("&")
        return url[len(link_pref): end if end != -1 else None]
    elif url.startswith(link_pref_2):
        end = url.find("?")
        return url[len(link_pref_2): end if end != -1 else None]
    else:
        raise ValueError("YOU NEED TO PASTE A YOUTUBE LINK 🤡🤡🤡!!!\nYOU PASTED " + url)
@st.cache_data
def remove_repeated_substrings(s, n, k):
    if k < 2 or n < 1 or len(s) < n * k:
        return s
    original = s
    max_m = len(original) // k
    for m in range(max_m, n - 1, -1):
        # (.{m}) captures an m-character block; \1{k-1,} requires at least
        # k - 1 immediate repeats of it, so the block occurs k+ times total.
        pattern = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            new_s, replacements = pattern.subn(r"\1", original)
            if replacements == 0:
                break
            original = new_s
    return original
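
# Stream a comment from the fine-tuned model. The prompt/answer separators
# below are assumed to match the format used during fine-tuning; the final
# yield re-emits the full text with degenerate repetition collapsed.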
def generate_answer(_model, _tokenizer, prompt_text,
                    _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7,
                    do_sample=True):
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator
    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    input_length = input_ids.shape[1]
    model_max_length = getattr(_tokenizer, "model_max_length", 1024)
    if input_length >= model_max_length - max_new_tokens:
        # Keep only the most recent tokens so the prompt plus the generated
        # continuation fits within the model's context window.
        allowed_input_length = model_max_length - max_new_tokens - 5
        inputs["input_ids"] = input_ids[:, -allowed_input_length:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -allowed_input_length:]
    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # GPT-2 tokenizers often define no pad token; fall back to EOS.
        pad_token_id=_tokenizer.pad_token_id if _tokenizer.pad_token_id is not None else _tokenizer.eos_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks, so run it in a worker thread and consume the
    # streamer from this one, yielding the accumulated text as it grows.
    thread = threading.Thread(target=_model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    # Final yield: strip repeated substrings and literal "\n" escapes.
    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text
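
# Metadata fields fed into the prompt, and the template that lays them out.
# The template is assumed to mirror the formatting the model saw during
# fine-tuning on video metadata/comment pairs.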
RELEVANT_FIELDS = ['title', 'channel_title', 'category', 'tags', 'views', 'likes', 'dislikes']
PROMPT_TEMPLATE = (
    "Video Information:\n"
    "Title: {title}\n"
    "Channel: {channel_title}\n"
    "Category: {category}\n"
    "Tags: {tags}\n"
    "Views: {views}\n"
    "Likes: {likes}\n"
    "Dislikes: {dislikes}\n"
    "\n"
    "Comment:\n"
    "{comment_text}"
)
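
# Render one video's metadata dict into the prompt text above. Missing
# fields fall back to 'N/A'; comment_text stays empty at inference time,
# since the comment is exactly what the model is asked to produce.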
@st.cache_data
def format_data_for_lm(example):
    try:
        metadata = {field: str(example.get(field, 'N/A')) for field in RELEVANT_FIELDS}
        metadata['comment_text'] = str(example.get('comment_text', ''))
        formatted_text = PROMPT_TEMPLATE.format(**metadata)
        return {"text": formatted_text}
    except Exception as e:
        st.error(f"Error formatting example: {e}")
        return {"text": ""}
@st.cache_resource
def build_service():
    key = "AIzaSyB3hMSp3LgMbpr-gD-btHWeKAvf7PhrPiw"
    return build("youtube", "v3", developerKey=key)
@st.cache_data
def get_video_info(video_id):
    api = build_service()
    response = api.videos().list(part="snippet,contentDetails,statistics", id=video_id).execute()
    items = response.get('items', [])
    if not items:
        raise ValueError(f"No video found for id {video_id}")
    snippet = items[0]['snippet']
    statistics = items[0]['statistics']
    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2] != "en":
        raise ValueError(f"Language {lang} is not supported")
    video_info = {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': snippet['categoryId'],
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics['viewCount'],
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }
    return video_info
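
# Load the fine-tuned GPT-2 commenter once per process; st.cache_resource
# keeps the weights in memory across Streamlit reruns instead of reloading
# them on every interaction.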
@st.cache_resource
def load_model():
    model_name = "Alekhon/gpt2-clown-commenter"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model
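
# End-to-end pipeline for one link: resolve the video ID, pull its metadata,
# build the prompt, and return a streaming generator for the comment text.
# Returns (True, generator) on success or (False, error message) on failure.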
def predict_comment(link):
    try:
        video_id = get_video_id(link)
        video_info = get_video_info(video_id)
        prompt = format_data_for_lm(video_info)['text']
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"
st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")
st.image(Image.open("hw4/clown.jpg"), width=400)
tokenizer, model = load_model()
link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")
if st.button('Generate! 🤡'):
if not link:
st.warning("Please enter a YouTube link")
else:
success, result = predict_comment(link)
if success:
generating_placeholder = st.empty()
generating_placeholder.status("Generating...")
comment_placeholder = st.empty()
final_text = ""
for partial_text in result:
final_text = partial_text
comment_placeholder.markdown(f"**Comment:**\n{final_text}")
generating_placeholder.empty()
processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
else:
st.error(result)