Alekhon commited on
Commit
cfd0432
·
verified ·
1 Parent(s): dd9ca70

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
5
+ import pandas as pd
6
+ from googleapiclient.discovery import build
7
+ import re
8
+ import threading
9
+
10
@st.cache_resource
def get_video_id(url):
    """Extract the video id from a YouTube URL.

    Supports the two common link shapes:
      * https://www.youtube.com/watch?v=<id>[&...]
      * https://youtu.be/<id>[?...]

    Args:
        url: the YouTube link pasted by the user.

    Returns:
        The bare video id string.

    Raises:
        Exception: if the URL matches neither supported prefix.
    """
    # NOTE(review): st.cache_resource on a pure string function — st.cache_data
    # would be the conventional choice; kept as-is to preserve behavior.
    link_pref = "https://www.youtube.com/watch?v="
    link_pref_2 = "https://youtu.be/"
    if url.startswith(link_pref):
        # Drop any trailing query parameters after the id (e.g. "&t=42").
        end = url.find("&")
        return url[len(link_pref): end if end != -1 else None]
    elif url.startswith(link_pref_2):
        # Short links carry extra parameters after "?" instead of "&".
        end = url.find("?")
        return url[len(link_pref_2): end if end != -1 else None]
    else:
        # Bug fix: the original appended the module-level `link` variable here
        # instead of the `url` parameter, which could report the wrong value
        # (or raise NameError when called before `link` exists).
        raise Exception("YOU NEED TO PASTE YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)
22
+
23
@st.cache_resource
def remove_repeated_substrings(s, n, k):
    """Collapse consecutive runs of a repeated substring to one occurrence.

    Any unit of length >= n repeated k or more times back-to-back is replaced
    by a single copy. Longer units are collapsed before shorter ones.

    Args:
        s: input string.
        n: minimum repeating-unit length to consider.
        k: minimum number of consecutive repetitions required.

    Returns:
        The collapsed string, or `s` unchanged when the parameters make a
        qualifying repetition impossible.
    """
    if k < 2 or n < 1 or len(s) < n * k:
        return s

    original = s
    # A unit repeated k times can be at most len(s) // k characters long.
    max_m = len(original) // k

    for m in range(max_m, n - 1, -1):
        # Fix: the original compiled ((.{m}))\1{k-1,} — the inner group was
        # redundant (group 1 and group 2 captured the same text). A single
        # capture group is equivalent for both matching and the \1 replacement.
        # NOTE(review): '.' does not match newlines here, so repeats spanning
        # '\n' are never collapsed — presumably acceptable; confirm intent.
        pattern = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            # Re-run until stable: one collapse can create a new adjacency.
            new_s, replacements = pattern.subn(r"\1", original)
            if replacements == 0:
                break
            original = new_s

    return original
40
+
41
def generate_answer(_model, _tokenizer, prompt_text, _device=None, max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7, do_sample=True):
    """Stream a generated continuation for `prompt_text`.

    Runs `_model.generate` on a background thread with a
    TextIteratorStreamer, yielding progressively longer partial strings as
    tokens arrive, then one final post-processed string (repeated substrings
    collapsed and literal "\\n" sequences removed).

    Args:
        _model: causal LM with a `.generate` method.
        _tokenizer: matching tokenizer.
        prompt_text: user prompt inserted between the fixed separators.
        _device: torch device; defaults to CUDA when available.
        max_new_tokens / temperature / top_k / top_p / do_sample:
            forwarded sampling parameters.

    Yields:
        str: growing partial outputs, then the final cleaned text.
    """
    # Fix: resolve the device lazily instead of in the default argument,
    # which was evaluated once at import time.
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator

    inputs = _tokenizer(input_text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    input_length = input_ids.shape[1]
    model_max_length = _tokenizer.model_max_length if hasattr(_tokenizer, 'model_max_length') else 1024

    # Keep the tail of an over-long prompt so generation has headroom for
    # max_new_tokens (plus a small safety margin of 5 tokens).
    if input_length >= model_max_length - max_new_tokens:
        allowed_input_length = model_max_length - max_new_tokens - 5
        input_ids = input_ids[:, -allowed_input_length:]
        inputs['input_ids'] = input_ids
        inputs['attention_mask'] = inputs['attention_mask'][:, -allowed_input_length:]

    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Generation runs on a worker thread; the streamer feeds this generator.
    thread = threading.Thread(target=_model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # Fix: join the worker so the generation thread cannot outlive the
    # generator (the streamer is exhausted only once generation finished,
    # so this returns promptly).
    thread.join()

    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text
81
+
82
# Metadata keys read from the YouTube API result and substituted into the prompt.
RELEVANT_FIELDS = ['title', 'channel_title', 'category', 'tags', 'views', 'likes', 'dislikes']
# Prompt layout consumed by format_data_for_lm; placeholders match RELEVANT_FIELDS
# plus 'comment_text'.
PROMPT_TEMPLATE = "Video Information:\nTitle: {title}\nChannel: {channel_title}\nCategory: {category}\nTags: {tags}\nViews: {views}\nLikes: {likes}\nDislikes: {dislikes}\n\nComment:\n{comment_text}"
84
+
85
@st.cache_resource
def format_data_for_lm(example):
    """Render one example mapping into the model's prompt string.

    Missing metadata fields default to 'N/A'; a missing comment defaults to
    the empty string. On any failure the error is surfaced in the UI and an
    empty prompt is returned.
    """
    try:
        fields = {key: str(example.get(key, 'N/A')) for key in RELEVANT_FIELDS}
        fields['comment_text'] = str(example.get('comment_text', ''))
        return {"text": PROMPT_TEMPLATE.format(**fields)}
    except Exception as e:
        st.error(f"Error formatting example: {e}")
        return {"text": ""}
95
+
96
@st.cache_resource
def build_service():
    """Create (and cache) a YouTube Data API v3 client.

    SECURITY: the original hard-coded the API key in source. Prefer the
    YOUTUBE_API_KEY environment variable; the embedded key (already public in
    this repo's history) remains only as a backward-compatible fallback and
    should be rotated.
    """
    import os  # local import keeps this fix self-contained
    key = os.environ.get("YOUTUBE_API_KEY", "AIzaSyB3hMSp3LgMbpr-gD-btHWeKAvf7PhrPiw")
    return build("youtube", "v3", developerKey=key)
100
+
101
@st.cache_resource
def get_video_info(video_id):
    """Fetch prompt metadata (title, channel, category, tags, stats) for a video.

    Raises:
        Exception: when the video's default audio language is not English.
    """
    api = build_service()
    response = api.videos().list(part="snippet,contentDetails,statistics", id=video_id).execute()

    # Hoist the repeated response['items'][0] lookups into locals.
    item = response['items'][0]
    snippet = item['snippet']
    statistics = item['statistics']

    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2] != "en":
        raise Exception(f"Language {lang}, not supported")

    return {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': snippet['categoryId'],
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics['viewCount'],
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }
119
+
120
@st.cache_resource
def load_model():
    """Load the fine-tuned GPT-2 commenter once, moving it to GPU when available.

    Returns:
        (tokenizer, model) pair.
    """
    model_name = "Alekhon/gpt2-clown-commenter"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return tok, lm
128
+
129
def predict_comment(link):
    """Resolve a YouTube link into a streaming comment generator.

    Returns:
        (True, generator) on success, or (False, error message) on any failure
        along the id-extraction / metadata-fetch / prompt-formatting pipeline.
    """
    try:
        prompt = format_data_for_lm(get_video_info(get_video_id(link)))['text']
        return True, generate_answer(model, tokenizer, prompt)
    except Exception as e:
        return False, f"Error generating comment: {e}"
137
+
138
# --- Streamlit page body (runs top to bottom on every interaction) ---
st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")

# assumes hw4/clown.jpg exists relative to the working directory — TODO confirm
st.image(Image.open("hw4/clown.jpg"), width=400)

# Module-level globals read by predict_comment().
tokenizer, model = load_model()

link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")
if st.button('Generate! 🤡'):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            # Placeholder for the transient "Generating..." indicator.
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""

            # `result` is the generate_answer generator: each item is a longer
            # partial string; the markdown is redrawn in place as tokens arrive.
            for partial_text in result:
                final_text = partial_text
                comment_placeholder.markdown(f"**Comment:**\n{final_text}")

            generating_placeholder.empty()
            # NOTE(review): the generator's last yield is already post-processed;
            # this applies the same cleanup a second time (harmless but redundant).
            processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
            comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)