File size: 6,391 Bytes
cfd0432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import streamlit as st
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import pandas as pd
from googleapiclient.discovery import build
import re
import threading

def get_video_id(url):
    """Return the video id contained in a YouTube link.

    Two link shapes are recognised:

    * ``https://www.youtube.com/watch?v=<id>``
    * ``https://youtu.be/<id>``

    In both cases the id ends at the first ``?``, ``&`` or ``#`` (query
    parameters such as ``&t=42s`` or ``?si=...`` and fragments are dropped).
    Whitespace around the pasted link is ignored.

    Args:
        url: Link pasted by the user.

    Returns:
        The non-empty video id.

    Raises:
        ValueError: ``url`` is not one of the recognised shapes, or it has
            the right prefix but carries no id.
    """
    watch_prefix = "https://www.youtube.com/watch?v="
    short_prefix = "https://youtu.be/"
    cleaned = url.strip()

    video_id = ""
    for prefix in (watch_prefix, short_prefix):
        if cleaned.startswith(prefix):
            # Everything after the prefix up to the first query/fragment separator.
            video_id = re.split(r"[?&#]", cleaned[len(prefix):], maxsplit=1)[0]
            break

    if not video_id:
        # Echo the caller's own input so the user sees exactly what was rejected.
        raise ValueError("YOU NEED TO PASTE YOUTUBE LINK 🤡🤡🤡!!! \nYOU PASTED " + url)
    return video_id

def remove_repeated_substrings(s, n, k):
    """Collapse runs of a substring that repeats ``k`` or more times in a row.

    For every unit length ``m`` from ``len(s) // k`` down to ``n``, any block
    of ``m`` characters immediately followed by at least ``k - 1`` copies of
    itself is replaced by a single copy. Each length is re-applied until no
    run of that length remains before moving on to the next shorter length,
    so long repeated phrases are collapsed before short ones.

    The pattern uses ``.``, which does not match newline characters, so runs
    whose unit contains a newline are left untouched.

    Args:
        s: Text to clean (e.g. a generated comment).
        n: Shortest unit length to consider; values < 1 disable cleaning.
        k: Minimum number of consecutive occurrences; values < 2 disable
            cleaning.

    Returns:
        The cleaned text, or ``s`` unchanged when the arguments disable
        cleaning or ``s`` is shorter than ``n * k``.
    """
    if k < 2 or n < 1 or len(s) < n * k:
        return s

    result = s
    # The longest unit length is fixed from the *input* length.
    for m in range(len(s) // k, n - 1, -1):
        # Group 1 is one copy of the unit; \1{k-1,} matches its repeats.
        run = re.compile(r"(.{" + str(m) + r"})\1{" + str(k - 1) + r",}")
        while True:
            result, replaced = run.subn(r"\1", result)
            if not replaced:
                break
    return result

def generate_answer(_model, _tokenizer, prompt_text, _device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), max_new_tokens=100, temperature=1.3, top_k=90, top_p=0.7, do_sample=True):
    """Stream a comment generated by ``_model`` for ``prompt_text``.

    The prompt is wrapped in the ``>>> Prompt:`` / ``>>> Answer:`` markers,
    tokenized, and passed to ``_model.generate`` on a background thread. The
    decoded text is read back through a ``TextIteratorStreamer``.

    Args:
        _model: Causal LM whose ``generate`` accepts a ``streamer`` argument.
        _tokenizer: Tokenizer matching ``_model``.
        prompt_text: Prompt text describing the video.
        _device: Device the input tensors are moved to; must match ``_model``.
        max_new_tokens, temperature, top_k, top_p, do_sample: Forwarded to
            ``_model.generate``.

    Yields:
        The text accumulated so far after each streamed chunk. It then yields
        once more the final text, with repeated substrings collapsed
        (``remove_repeated_substrings(text, 2, 5)``) and literal backslash-n
        sequences removed.

    Raises:
        Exception: Whatever ``_model.generate`` raised on the worker thread,
            re-raised here after the stream has been closed.
    """
    prompt_separator = "\n>>> Prompt:\n"
    answer_separator = "\n>>> Answer:\n"
    input_text = prompt_separator + prompt_text + answer_separator

    # Tokenize the full prompt without ``truncation=True``: that option cuts
    # from the right by default and would drop the answer separator the model
    # has to continue from. Over-long prompts are trimmed from the left below.
    inputs = _tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(_device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]
    model_max_length = getattr(_tokenizer, "model_max_length", 1024)

    if input_length >= model_max_length - max_new_tokens:
        # Keep only the newest tokens, leaving room for max_new_tokens plus
        # 5 tokens of slack. Never slice with a non-positive length.
        allowed_input_length = max(1, model_max_length - max_new_tokens - 5)
        # Slice every returned tensor so they stay the same length.
        inputs = {k: v[:, -allowed_input_length:] for k, v in inputs.items()}

    pad_token_id = _tokenizer.pad_token_id
    if pad_token_id is None:
        # Stock GPT-2 tokenizers define no pad token; use EOS explicitly
        # instead of relying on generate()'s warning-emitting fallback.
        pad_token_id = _tokenizer.eos_token_id

    streamer = TextIteratorStreamer(_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    worker_errors = []

    def _run_generation():
        # If generate() raises, it never signals end-of-stream, and the
        # ``for`` loop below would block forever on the streamer. So record
        # the error and close the stream here.
        try:
            _model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    # daemon=True: an abandoned stream must not keep the process alive.
    thread = threading.Thread(target=_run_generation, daemon=True)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()
    if worker_errors:
        raise worker_errors[0]

    processed_text = remove_repeated_substrings(generated_text, 2, 5).replace("\\n", "")
    yield processed_text

# Video metadata keys copied into the prompt, listed in template order.
RELEVANT_FIELDS = [
    "title",
    "channel_title",
    "category",
    "tags",
    "views",
    "likes",
    "dislikes",
]
# Prompt layout fed to the model. str.format() must receive every
# RELEVANT_FIELDS entry plus ``comment_text``.
PROMPT_TEMPLATE = (
    "Video Information:\n"
    "Title: {title}\n"
    "Channel: {channel_title}\n"
    "Category: {category}\n"
    "Tags: {tags}\n"
    "Views: {views}\n"
    "Likes: {likes}\n"
    "Dislikes: {dislikes}\n"
    "\n"
    "Comment:\n"
    "{comment_text}"
)

def format_data_for_lm(example):
    """Render video metadata into the prompt text expected by the model.

    Args:
        example: Mapping holding some or all of the ``RELEVANT_FIELDS`` keys
            and an optional ``comment_text``. Missing metadata fields become
            ``'N/A'``; a missing ``comment_text`` becomes ``''``. Every value
            is converted with ``str()``.

    Returns:
        ``{"text": <formatted prompt>}``. If formatting fails, an error is
        shown on the Streamlit page and ``{"text": ""}`` is returned instead
        of raising.
    """
    try:
        metadata = {field: str(example.get(field, 'N/A')) for field in RELEVANT_FIELDS}
        # get_video_info() supplies no comment_text, so at inference time the
        # prompt ends with an empty "Comment:" section.
        metadata['comment_text'] = str(example.get('comment_text', ''))
        return {"text": PROMPT_TEMPLATE.format(**metadata)}
    except Exception as e:
        # UI boundary: report the problem on the page and degrade to an empty prompt.
        st.error(f"Error formatting example: {e}")
        return {"text": ""}

@st.cache_resource
def build_service():
    """Create the YouTube Data API v3 client.

    ``st.cache_resource`` keeps a single client instance instead of building
    one on every call.

    The API key is taken from the ``YOUTUBE_API_KEY`` environment variable
    when it is set and non-empty.

    Returns:
        The ``googleapiclient`` resource for the ``youtube`` v3 API.
    """
    import os

    # SECURITY: the fallback key below is embedded in the source, so anyone
    # who can read this file can use its quota. Rotate it, provide
    # YOUTUBE_API_KEY through the environment, then delete the fallback. It
    # is kept only so existing deployments keep working.
    key = os.environ.get("YOUTUBE_API_KEY") or "AIzaSyB3hMSp3LgMbpr-gD-btHWeKAvf7PhrPiw"
    return build("youtube", "v3", developerKey=key)

@st.cache_data(ttl=3600)
def get_video_info(video_id):
    """Fetch the metadata fields used in the prompt for one YouTube video.

    Results are cached per ``video_id`` for an hour. Repeated requests for
    the same video therefore do not call the API again, while view and like
    counts still refresh. ``st.cache_data`` gives each caller its own copy
    of the returned dict.

    Args:
        video_id: Id of the video to look up.

    Returns:
        Dict with the keys ``title``, ``channel_title``, ``category`` (the
        snippet's ``categoryId``), ``tags`` (joined with ``|``), ``views``,
        ``likes`` and ``dislikes``. Counts missing from the API response
        default to ``0``.

    Raises:
        ValueError: The API returned no video for this id, or the video's
            ``defaultAudioLanguage`` is not English.
    """
    api = build_service()
    # Only the snippet and statistics parts are read below.
    response = api.videos().list(part="snippet,statistics", id=video_id).execute()
    items = response.get('items') or []
    if not items:
        raise ValueError(f"Video {video_id!r} not found (it may be private, deleted or mistyped)")
    snippet = items[0]['snippet']
    statistics = items[0].get('statistics', {})

    # Videos without an audio-language tag are treated as English.
    lang = snippet.get('defaultAudioLanguage', 'en')
    if lang[:2].lower() != "en":
        raise ValueError(f"Language {lang}, not supported")

    return {
        'title': snippet['title'],
        'channel_title': snippet['channelTitle'],
        'category': snippet['categoryId'],
        'tags': '|'.join(snippet.get('tags', [])),
        'views': statistics.get('viewCount', 0),
        'likes': statistics.get('likeCount', 0),
        'dislikes': statistics.get('dislikeCount', 0),
    }

@st.cache_resource
def load_model():
    """Load the fine-tuned commenter checkpoint and its tokenizer.

    Decorated with ``st.cache_resource`` so the same objects are reused
    rather than reloaded on every script rerun.

    Returns:
        ``(tokenizer, model)``, with the model placed on CUDA when
        ``torch.cuda.is_available()`` and on CPU otherwise.
    """
    checkpoint = "Alekhon/gpt2-clown-commenter"
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    lm.to(target_device)
    return tok, lm

def predict_comment(link):
    """Turn a pasted YouTube link into a stream of generated comment text.

    Uses the module-level ``model`` and ``tokenizer``.

    Returns:
        ``(True, stream)``, where ``stream`` is the generator returned by
        ``generate_answer``. Returns ``(False, message)`` when parsing the
        link, fetching the video metadata, or building the prompt raises.

    NOTE(review): ``generate_answer`` is a generator, so tokenization and
    generation run only when the caller iterates ``stream``. Errors raised
    there are not turned into ``(False, message)`` by this function.
    """
    try:
        metadata = get_video_info(get_video_id(link))
        prompt = format_data_for_lm(metadata)['text']
        stream = generate_answer(model, tokenizer, prompt)
    except Exception as err:
        return False, f"Error generating comment: {err}"
    return True, stream

st.markdown("# YouTube Comment Generator")
st.markdown("### Generate comments using video metadata!")

st.image(Image.open("hw4/clown.jpg"), width=400)

# load_model is wrapped in st.cache_resource, so reruns reuse the loaded objects.
tokenizer, model = load_model()

link = st.text_input("Enter YouTube Link", placeholder="Paste URL here...")
if st.button('Generate! 🤡'):
    if not link:
        st.warning("Please enter a YouTube link")
    else:
        success, result = predict_comment(link)
        if success:
            generating_placeholder = st.empty()
            generating_placeholder.status("Generating...")
            comment_placeholder = st.empty()
            final_text = ""

            try:
                # ``result`` is a lazy generator: tokenization and generation
                # run while it is iterated, so their failures surface here
                # rather than inside predict_comment().
                for partial_text in result:
                    final_text = partial_text
                    comment_placeholder.markdown(f"**Comment:**\n{final_text}")
            except Exception as e:
                # Clear the status and replace any partial comment with the error.
                generating_placeholder.empty()
                comment_placeholder.error(f"Error generating comment: {e}")
            else:
                generating_placeholder.empty()
                processed_text = remove_repeated_substrings(final_text, 2, 5).replace("\\n", "")
                comment_placeholder.success(f"**Final Comment:**\n{processed_text}")
        else:
            st.error(result)