Simplified interface
Browse files
- app.py +27 -40
- infer.py +0 -90
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -16,7 +16,6 @@ from infer import (
|
|
| 16 |
find_answer_start,
|
| 17 |
get_noising_schedule,
|
| 18 |
noisify_answer,
|
| 19 |
-
generate_diffusion_text,
|
| 20 |
filter_logits,
|
| 21 |
confidence_guided_noising,
|
| 22 |
noisify_answer_without_remasking
|
|
@@ -39,17 +38,17 @@ rng = np.random.default_rng()
|
|
| 39 |
def generate_diffusion_text(input_ids, top_p, top_k):
|
| 40 |
with torch.no_grad():
|
| 41 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
| 42 |
-
|
|
|
|
| 43 |
logits = model(input_ids=input_tensor)["logits"]
|
| 44 |
-
|
|
|
|
| 45 |
logits = logits.clamp(min=-1e8, max=1e4)
|
| 46 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
| 47 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
| 48 |
-
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
| 49 |
-
assert (probs >= 0).all(), "Negative probs!"
|
| 50 |
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
|
| 51 |
-
|
| 52 |
-
# Extract confidence of selected tokens
|
| 53 |
conf = probs[range(len(sampled)), sampled].cpu().numpy()
|
| 54 |
return sampled, conf
|
| 55 |
|
|
@@ -79,10 +78,14 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
|
|
| 79 |
highlighted.append(tok_str)
|
| 80 |
return "".join(highlighted)
|
| 81 |
|
| 82 |
-
def diffusion_chat(question,
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
if question.strip() == "":
|
| 88 |
question = "What do you know about the city of Amsterdam?"
|
|
@@ -111,6 +114,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
|
|
| 111 |
unmasked_mask = [False] * len(current_tokens)
|
| 112 |
|
| 113 |
for i in range(max_it):
|
|
|
|
| 114 |
generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
|
| 115 |
current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
|
| 116 |
|
|
@@ -133,25 +137,15 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
|
|
| 133 |
if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
|
| 134 |
yield render_html("Stopped early", f"After {i+1} iterations.")
|
| 135 |
break
|
| 136 |
-
|
| 137 |
# NOISING
|
| 138 |
-
if i < max_it-1:
|
| 139 |
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
elif use_permanent_unmasking:
|
| 146 |
-
noised_answer, just_noised_indices = noisify_answer_without_remasking(
|
| 147 |
-
current_tokens, answer_start, tokenizer, threshold=threshold,
|
| 148 |
-
noise_start=noise_start, unmasked_mask=unmasked_mask
|
| 149 |
-
)
|
| 150 |
-
else:
|
| 151 |
-
noised_answer, just_noised_indices = noisify_answer(
|
| 152 |
-
current_tokens, answer_start, tokenizer,
|
| 153 |
-
threshold=threshold, clustering=clustering, noise_start=noise_start
|
| 154 |
-
)
|
| 155 |
|
| 156 |
for idx in range(answer_start, len(current_tokens)):
|
| 157 |
if noised_answer[idx] != mask_token_id:
|
|
@@ -172,7 +166,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
|
|
| 172 |
final_ids = answer_ids
|
| 173 |
|
| 174 |
final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
|
| 175 |
-
yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
|
| 176 |
|
| 177 |
|
| 178 |
def is_running_on_spaces():
|
|
@@ -197,22 +191,15 @@ print("✅ Model loaded.")
|
|
| 197 |
vocab_size = len(tokenizer)
|
| 198 |
eos_token_id = tokenizer.eos_token_id
|
| 199 |
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
|
| 200 |
-
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
|
| 201 |
|
| 202 |
demo = gr.Interface(
|
| 203 |
fn=diffusion_chat,
|
| 204 |
inputs=[
|
| 205 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
| 206 |
-
gr.
|
| 207 |
-
gr.Slider(
|
| 208 |
-
gr.Slider(
|
| 209 |
-
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
|
| 210 |
-
gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
|
| 211 |
-
gr.Checkbox(value=False, label="Use confidence-guided noising"),
|
| 212 |
-
gr.Checkbox(value=False, label="Use permanent unmasking"),
|
| 213 |
-
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
|
| 214 |
-
gr.Slider(1, 1000, value = 3, step = 1, label = "Top-p: ↑ = more random answers"),
|
| 215 |
-
gr.Slider(0.0, 1.0, value = 1.0, step = 0.01, label = "Top-k: ↑ = more random answers")
|
| 216 |
],
|
| 217 |
outputs=[gr.HTML(label="Diffusion Output")],
|
| 218 |
title="Diffusion Language Model Chat",
|
|
|
|
| 16 |
find_answer_start,
|
| 17 |
get_noising_schedule,
|
| 18 |
noisify_answer,
|
|
|
|
| 19 |
filter_logits,
|
| 20 |
confidence_guided_noising,
|
| 21 |
noisify_answer_without_remasking
|
|
|
|
| 38 |
def generate_diffusion_text(input_ids, top_p, top_k):
|
| 39 |
with torch.no_grad():
|
| 40 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
| 41 |
+
|
| 42 |
+
with torch.cuda.amp.autocast(dtype=torch.float16):
|
| 43 |
logits = model(input_ids=input_tensor)["logits"]
|
| 44 |
+
|
| 45 |
+
logits = filter_logits(logits, top_k=top_k, top_p=top_p)
|
| 46 |
logits = logits.clamp(min=-1e8, max=1e4)
|
| 47 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
| 48 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
| 49 |
+
# assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
| 50 |
+
# assert (probs >= 0).all(), "Negative probs!"
|
| 51 |
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
|
|
|
|
|
|
|
| 52 |
conf = probs[range(len(sampled)), sampled].cpu().numpy()
|
| 53 |
return sampled, conf
|
| 54 |
|
|
|
|
| 78 |
highlighted.append(tok_str)
|
| 79 |
return "".join(highlighted)
|
| 80 |
|
| 81 |
+
def diffusion_chat(question, noising, max_it, pause_length):
|
| 82 |
+
|
| 83 |
+
pause_length = 0
|
| 84 |
+
sharpness = 3.0
|
| 85 |
+
noise_start = 0.5
|
| 86 |
+
top_p = 1.0
|
| 87 |
+
top_k = 10
|
| 88 |
+
clustering = False
|
| 89 |
|
| 90 |
if question.strip() == "":
|
| 91 |
question = "What do you know about the city of Amsterdam?"
|
|
|
|
| 114 |
unmasked_mask = [False] * len(current_tokens)
|
| 115 |
|
| 116 |
for i in range(max_it):
|
| 117 |
+
|
| 118 |
generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
|
| 119 |
current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
|
| 120 |
|
|
|
|
| 137 |
if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
|
| 138 |
yield render_html("Stopped early", f"After {i+1} iterations.")
|
| 139 |
break
|
| 140 |
+
|
| 141 |
# NOISING
|
| 142 |
+
if i < max_it-1 and noising:
|
| 143 |
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
| 144 |
+
|
| 145 |
+
noised_answer, just_noised_indices = noisify_answer(
|
| 146 |
+
current_tokens, answer_start, tokenizer,
|
| 147 |
+
threshold=threshold, clustering=clustering, noise_start=noise_start
|
| 148 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
for idx in range(answer_start, len(current_tokens)):
|
| 151 |
if noised_answer[idx] != mask_token_id:
|
|
|
|
| 166 |
final_ids = answer_ids
|
| 167 |
|
| 168 |
final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
|
| 169 |
+
yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output) # type: ignore
|
| 170 |
|
| 171 |
|
| 172 |
def is_running_on_spaces():
|
|
|
|
| 191 |
vocab_size = len(tokenizer)
|
| 192 |
eos_token_id = tokenizer.eos_token_id
|
| 193 |
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
|
| 194 |
+
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
|
| 195 |
|
| 196 |
demo = gr.Interface(
|
| 197 |
fn=diffusion_chat,
|
| 198 |
inputs=[
|
| 199 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
| 200 |
+
gr.Checkbox(label="Enable noising", value=True, info="If disabled, the model will not apply any intermediate noise."),
|
| 201 |
+
gr.Slider(1, 512, value=64, step=1, label="Increase the maximum number of iterations to run."),
|
| 202 |
+
gr.Slider(0, 5, value=0, step=0.01, label="Increase the pause between iterations to visualize the process.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
],
|
| 204 |
outputs=[gr.HTML(label="Diffusion Output")],
|
| 205 |
title="Diffusion Language Model Chat",
|
infer.py
CHANGED
|
@@ -190,26 +190,6 @@ def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, n
|
|
| 190 |
noised_indices = sorted(noised_indices)
|
| 191 |
return noised, noised_indices
|
| 192 |
|
| 193 |
-
def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
|
| 194 |
-
eos_token_id=None, eos_boost=0.0):
|
| 195 |
-
model.eval()
|
| 196 |
-
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 197 |
-
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
| 198 |
-
logits = model(input_ids=input_tensor)["logits"] # (1, seq_len, vocab_size)
|
| 199 |
-
|
| 200 |
-
# Optionally boost or suppress EOS token
|
| 201 |
-
if eos_token_id is not None and eos_boost != 0.0:
|
| 202 |
-
logits[:, :, eos_token_id] += eos_boost
|
| 203 |
-
|
| 204 |
-
# Filter and sample
|
| 205 |
-
filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
| 206 |
-
probs = F.softmax(filtered_logits, dim=-1).squeeze() # (seq_len, vocab_size)
|
| 207 |
-
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
| 208 |
-
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 209 |
-
confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
|
| 210 |
-
|
| 211 |
-
return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences
|
| 212 |
-
|
| 213 |
|
| 214 |
def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
|
| 215 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
@@ -277,73 +257,3 @@ def save_html_colored_output(filename, html_content):
|
|
| 277 |
</body>
|
| 278 |
</html>
|
| 279 |
""")
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
|
| 283 |
-
noising_sharpness=5.0, max_length=256, top_k=100, top_p=1.0,
|
| 284 |
-
temperature=1.0, eos_token_id = None, eos_boost = 0.0) -> str:
|
| 285 |
-
|
| 286 |
-
if eos_token_id is None:
|
| 287 |
-
eos_token_id = tokenizer.eos_token_id
|
| 288 |
-
# Format prompt with LLaMA 3 chat template
|
| 289 |
-
prompt = (
|
| 290 |
-
"<|begin_of_text|>\n"
|
| 291 |
-
"<|start_header_id|>system<|end_header_id|>\n"
|
| 292 |
-
"You are a helpful assistant.\n"
|
| 293 |
-
"<|eot_id|>\n"
|
| 294 |
-
"<|start_header_id|>user<|end_header_id|>\n"
|
| 295 |
-
f"{question.strip()}\n"
|
| 296 |
-
"<|start_header_id|>assistant<|end_header_id|>\n"
|
| 297 |
-
)
|
| 298 |
-
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
| 299 |
-
marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
|
| 300 |
-
|
| 301 |
-
def find_answer_start(ids, marker):
|
| 302 |
-
for i in range(len(ids) - len(marker) + 1):
|
| 303 |
-
if ids[i:i+len(marker)] == marker:
|
| 304 |
-
return i + len(marker)
|
| 305 |
-
return None
|
| 306 |
-
|
| 307 |
-
answer_start = find_answer_start(input_ids, marker)
|
| 308 |
-
if answer_start is None:
|
| 309 |
-
raise ValueError("Assistant marker not found in prompt.")
|
| 310 |
-
|
| 311 |
-
# Pad to max length
|
| 312 |
-
pad_token = tokenizer.eos_token_id
|
| 313 |
-
mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
|
| 314 |
-
input_ids = input_ids[:max_length]
|
| 315 |
-
if len(input_ids) < max_length:
|
| 316 |
-
input_ids += [mask_token] * (max_length - len(input_ids))
|
| 317 |
-
|
| 318 |
-
ori_tokens = input_ids
|
| 319 |
-
current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)
|
| 320 |
-
|
| 321 |
-
last_tokens = []
|
| 322 |
-
for step in range(max_it):
|
| 323 |
-
# Generate a new prediction
|
| 324 |
-
current_tokens, confidence_scores = generate_diffusion_text(
|
| 325 |
-
model, current_tokens, answer_start,
|
| 326 |
-
top_k=top_k, top_p=top_p, temperature=temperature,
|
| 327 |
-
eos_token_id=eos_token_id, eos_boost=eos_boost
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
# Display for debugging / tracking
|
| 331 |
-
display_diffusion_output(
|
| 332 |
-
step, max_it, question,
|
| 333 |
-
ori_tokens, current_tokens, confidence_scores,
|
| 334 |
-
answer_start, tokenizer
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
# Early stopping
|
| 338 |
-
last_tokens.append(current_tokens)
|
| 339 |
-
if len(last_tokens) > 4:
|
| 340 |
-
last_tokens.pop(0)
|
| 341 |
-
if all(t == last_tokens[0] for t in last_tokens):
|
| 342 |
-
break
|
| 343 |
-
|
| 344 |
-
# Re-apply noise for next iteration
|
| 345 |
-
if step < max_it - 1:
|
| 346 |
-
threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
|
| 347 |
-
current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold, mask_token_id=mask_token)
|
| 348 |
-
|
| 349 |
-
return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
|
|
|
|
| 190 |
noised_indices = sorted(noised_indices)
|
| 191 |
return noised, noised_indices
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
|
| 195 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
| 257 |
</body>
|
| 258 |
</html>
|
| 259 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -7,3 +7,4 @@ gradio>=4.10.0
|
|
| 7 |
numpy
|
| 8 |
load_dotenv
|
| 9 |
ipython
|
|
|
|
|
|
| 7 |
numpy
|
| 8 |
load_dotenv
|
| 9 |
ipython
|
| 10 |
+
spaces
|