Spaces:
Sleeping
Sleeping
File size: 6,391 Bytes
cfd0432 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import streamlit as st
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import pandas as pd
from googleapiclient.discovery import build
import re
import threading
@st.cache_resource
def get_video_id(url):
    """Extract the video id from a YouTube URL.

    Supports the two common URL shapes:
      * https://www.youtube.com/watch?v=<id>[&...]
      * https://youtu.be/<id>[?...]

    Args:
        url: The YouTube link pasted by the user.

    Returns:
        The video-id portion of the URL with any query string stripped.

    Raises:
        Exception: If the URL matches neither supported prefix.
    """
    watch_prefix = "https://www.youtube.com/watch?v="
    short_prefix = "https://youtu.be/"
    if url.startswith(watch_prefix):
        # Drop extra query parameters (e.g. "&t=42s").
        end = url.find("&")
        return url[len(watch_prefix): end if end != -1 else None]
    elif url.startswith(short_prefix):
        # Short links carry parameters after "?" instead of "&".
        end = url.find("?")
        return url[len(short_prefix): end if end != -1 else None]
    else:
        # BUG FIX: original referenced the global `link` instead of the
        # `url` parameter, causing a NameError when called with any other
        # argument name.
        raise Exception("YOU NEED TO PASTE YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)
@st.cache_resource
def remove_repeated_substrings(s, n, k):
    """Collapse runs of a substring repeated k or more times to one copy.

    Searches for substrings of length n..len(s)//k (longest first) that
    occur k or more times consecutively and replaces each run with a
    single occurrence, repeating until no run remains.

    Args:
        s: Input string.
        n: Minimum substring length to consider.
        k: Minimum number of consecutive repetitions to collapse.

    Returns:
        The string with repeated runs collapsed, or ``s`` unchanged when
        the parameters make a match impossible.
    """
    # Fast exit: k copies of an n-char substring cannot fit in s.
    if k < 2 or n < 1 or len(s) < n * k:
        return s
    collapsed = s
    # Longest candidate first so larger repeated units win over their
    # own internal repetitions.
    max_len = len(collapsed) // k
    for length in range(max_len, n - 1, -1):
        # FIX: the original wrapped the unit in two identical capture
        # groups ((.{m})); one group is equivalent and clearer.
        pattern = re.compile(r"(.{" + str(length) + r"})\1{" + str(k - 1) + r",}")
        # Re-run until stable: collapsing one run can create a new one.
        while True:
            collapsed, count = pattern.subn(r"\1", collapsed)
            if count == 0:
                break
    return collapsed
def generate_answer(_model, _tokenizer, prompt_text, _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7, do_sample=True):
    """Stream a generated comment for `prompt_text`, yielding growing text.

    Wraps the prompt in the ">>> Prompt:/>>> Answer:" template, runs
    `_model.generate` on a background thread, and yields the accumulated
    decoded text after each streamed chunk. The final yield is a
    post-processed version with repeated substrings collapsed and literal
    "\\n" artifacts removed.

    NOTE(review): `_device` is evaluated once at import time — presumably
    intentional since CUDA availability does not change mid-run; confirm.
    """
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator
    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    # Move every input tensor to the model's device.
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    input_length = input_ids.shape[1]
    model_max_length = _tokenizer.model_max_length if hasattr(_tokenizer, 'model_max_length') else 1024
    if input_length >= model_max_length - max_new_tokens:
        # Keep only the most recent tokens so the generation budget still
        # fits in the context window; the 5-token margin leaves headroom
        # for special tokens.
        allowed_input_length = model_max_length - max_new_tokens - 5
        input_ids = input_ids[:, -allowed_input_length:]
        inputs['input_ids'] = input_ids
        inputs['attention_mask'] = inputs['attention_mask'][:, -allowed_input_length:]
    # Streamer yields decoded text pieces as generate() produces tokens.
    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks, so run it on a worker thread while this generator
    # consumes the streamer on the caller's thread.
    thread = threading.Thread(target=_model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    # Final yield: collapse substrings (>=2 chars) repeated 5+ times and
    # strip literal backslash-n sequences left over from the training data.
    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text
# Metadata fields taken from the YouTube API response and fed to the prompt.
RELEVANT_FIELDS = ['title', 'channel_title', 'category', 'tags', 'views', 'likes', 'dislikes']
# Prompt layout filled via str.format(**metadata) in format_data_for_lm.
PROMPT_TEMPLATE = "Video Information:\nTitle: {title}\nChannel: {channel_title}\nCategory: {category}\nTags: {tags}\nViews: {views}\nLikes: {likes}\nDislikes: {dislikes}\n\nComment:\n{comment_text}"
@st.cache_resource
def format_data_for_lm(example):
    """Render a video-metadata dict into the model's prompt text.

    Missing metadata fields fall back to 'N/A' and a missing comment to
    the empty string. On any formatting error, shows a Streamlit error
    and returns an empty prompt instead of raising.
    """
    try:
        fields = {name: str(example.get(name, 'N/A')) for name in RELEVANT_FIELDS}
        fields['comment_text'] = str(example.get('comment_text', ''))
        return {"text": PROMPT_TEMPLATE.format(**fields)}
    except Exception as e:
        st.error(f"Error formatting example: {e}")
        return {"text": ""}
@st.cache_resource
def build_service():
    """Create a YouTube Data API v3 client.

    SECURITY(review): the API key was hard-coded in the source. Prefer the
    YOUTUBE_API_KEY environment variable (e.g. a Space secret); the literal
    is kept only as a backward-compatible fallback and should be rotated,
    since it has been committed to version control.
    """
    import os  # local import: keeps the file-level import block untouched
    key = os.environ.get("YOUTUBE_API_KEY", "AIzaSyB3hMSp3LgMbpr-gD-btHWeKAvf7PhrPiw")
    return build("youtube", "v3", developerKey=key)
@st.cache_resource
def get_video_info(video_id):
    """Fetch title/channel/category/tags/statistics for a video via the API.

    Args:
        video_id: YouTube video id as returned by get_video_id.

    Returns:
        Dict with the keys listed in RELEVANT_FIELDS.

    Raises:
        Exception: If no video matches the id, or its default audio
            language is not English (the comment model is English-only).
    """
    api = build_service()
    response = api.videos().list(part="snippet,contentDetails,statistics", id=video_id).execute()
    items = response.get('items', [])
    if not items:
        # FIX: an invalid/private id previously crashed with a bare
        # IndexError on response['items'][0]; fail with a clear message.
        raise Exception(f"No video found for id {video_id!r}")
    snippet = items[0]['snippet']
    statistics = items[0]['statistics']
    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2] != "en":
        raise Exception(f"Language {lang}, not supported")
    # Hoisted the repeated deep lookups into `snippet`/`statistics`.
    video_info = {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': snippet['categoryId'],
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics['viewCount'],
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }
    return video_info
@st.cache_resource
def load_model():
    """Download the fine-tuned GPT-2 commenter and move it to GPU if available."""
    checkpoint = "Alekhon/gpt2-clown-commenter"
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # .to() returns the model itself, so chaining preserves the original's
    # in-place device move.
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(target)
    return tokenizer, model
def predict_comment(link):
    """Resolve `link` to video metadata and return a (success, result) pair.

    On success `result` is the streaming generator from generate_answer;
    on any failure it is a human-readable error message.
    """
    try:
        info = get_video_info(get_video_id(link))
        prompt = format_data_for_lm(info)['text']
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"
# --- Streamlit page --------------------------------------------------------
st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")
st.image(Image.open("hw4/clown.jpg"), width=400)

tokenizer, model = load_model()

link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")
if st.button('Generate! 🤡'):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            # Transient status shown while tokens stream in.
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""
            # `result` yields the text-so-far; re-render the placeholder on
            # every chunk for a live-typing effect.
            for partial_text in result:
                final_text = partial_text
                comment_placeholder.markdown(f"**Comment:**\n{final_text}")
            generating_placeholder.empty()
            # Collapse repeated substrings and strip literal "\n" artifacts.
            # (The generator's last yield is already processed, so this is
            # an idempotent second pass.)
            processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
            # FIX: removed a trailing "|" artifact after st.error(result)
            # (scrape residue) that made the original line invalid Python.
            comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)