#!/usr/bin/env python
# coding: utf-8
import torch
from PIL import Image
import re
import base64
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Task prompt and model path
task_prompt = "<s_cord-v2>"  # decoder prompt (CORD-v2 task start token), meaning "parse this document"
pretrained_path = "sma1-rmarud/donut-cord-v2-menu-sample-demo"

# Load pretrained processor and model
processor = DonutProcessor.from_pretrained(pretrained_path)
pretrained_model = VisionEncoderDecoderModel.from_pretrained(pretrained_path)

device = torch.device("cpu")  # run on CPU
pretrained_model.to(device)
pretrained_model = pretrained_model.float()
pretrained_model.eval()
# Function to convert tokenized output to JSON format
def token2json(tokens, is_inner_value=False):
    output = dict()
    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        end_token = re.search(fr"</s_{key}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]
            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True)
    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}
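
# Illustrative example of token2json (the token string and field values below are made up,
# not actual model output; real keys depend on the fine-tuned CORD-v2 schema):
#   token2json("<s_menu><s_nm>Latte</s_nm><s_price>4,500</s_price></s_menu>")
#   -> {"menu": {"nm": "Latte", "price": "4,500"}}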
# Gradio demo process function
def demo_process(files):
    global pretrained_model, task_prompt, device
    results = []
    for file in files:
        input_img = Image.open(file).convert("RGB")
        input_img = input_img.resize((960, 640))
        pixel_values = processor(input_img, return_tensors="pt", padding=True).pixel_values.to(device)
        pixel_values = pixel_values.float()
        decoder_input_ids = torch.full((1, 1), pretrained_model.config.decoder_start_token_id, device=device)
        outputs = pretrained_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=pretrained_model.config.decoder.max_length,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        predictions = []
        for seq in processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()
            predictions.append(seq)
        results.append(token2json(predictions[0]))
    return results
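
# Quick local sanity check (illustrative sketch; "receipt.png" is a placeholder path,
# not a file shipped with this Space). demo_process returns one parsed dict per uploaded image:
#   if __name__ == "__main__":
#       print(demo_process(["receipt.png"]))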
# Base64 encode the SVG background
sprinkle_svg = """
<svg id="sprinkle-pattern" xmlns="http://www.w3.org/2000/svg" width="500" height="500">
  <defs>
    <pattern id="sprinkles" x="0" y="0" width="40" height="40" patternUnits="userSpaceOnUse">
      <rect x="10" y="10" width="10" height="3" rx="1.5" transform="rotate(45, 15, 11.5)" fill="#FF5252"/>
      <rect x="20" y="20" width="10" height="3" rx="1.5" transform="rotate(-30, 25, 21.5)" fill="#FFD740"/>
      <rect x="30" y="30" width="10" height="3" rx="1.5" transform="rotate(60, 35, 31.5)" fill="#40C4FF"/>
      <rect x="10" y="30" width="10" height="3" rx="1.5" transform="rotate(10, 15, 31.5)" fill="#69F0AE"/>
      <rect x="50" y="10" width="10" height="3" rx="1.5" transform="rotate(120, 55, 11.5)" fill="#EA80FC"/>
    </pattern>
  </defs>
  <rect width="100%" height="100%" fill="#FFFFFF"/>
  <rect width="100%" height="100%" fill="url(#sprinkles)" opacity="0.6"/>
</svg>
"""
encoded_svg = base64.b64encode(sprinkle_svg.encode('utf-8')).decode('ascii')
background_url = f"data:image/svg+xml;base64,{encoded_svg}"
# Gradio Interface
demo = gr.Interface(
    fn=demo_process,
    inputs=gr.File(file_types=["image"], label="Upload multiple images", file_count="multiple"),
    outputs=gr.JSON(),  # JSON list output
    description="<div class='white-title'>Donut 🍩 Demonstration</div>",  # title for the demo
    theme="soft",
    css=f"""
    .gradio-container {{
        background-color: #ffffff;
        background-image: url('{background_url}');
        background-size: cover;
        background-repeat: no-repeat;
        background-position: center;
    }}
    .white-title {{
        background: #fff;
        color: #333;
        font-weight: bold;
        font-size: 2rem;
        padding: 1rem 2rem;
        border-radius: 12px;
        text-align: center;
        margin-bottom: 24px;
        box-shadow: 0 2px 8px rgba(0,0,0,0.08);
        border: 1px solid #eee;
    }}
    """,
)

demo.launch(debug=True)