#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: parse uploaded images into JSON with a Donut model.

Loads a DonutProcessor / VisionEncoderDecoderModel pair from ``pretrained_path``,
runs greedy generation on CPU for every uploaded image, and converts the
generated tag sequence into nested dicts/lists with ``token2json``.
"""
import base64
import re

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Task prompt for the decoder (the original comment described it as the
# "question" to the decoder, meaning "parse this").
# NOTE(review): the literal was lost in extraction and is reconstructed from the
# model name — confirm against the prompt used in training. demo_process does
# not read it; decoding is seeded with config.decoder_start_token_id instead.
task_prompt = "<s_cord-v2>"
pretrained_path = "sma1-rmarud/donut-cord-v2-menu-sample-demo"

# (width, height) that every upload is resized to before preprocessing.
INPUT_SIZE = (960, 640)

# Load the pretrained processor and model.
processor = DonutProcessor.from_pretrained(pretrained_path)
pretrained_model = VisionEncoderDecoderModel.from_pretrained(pretrained_path)

device = torch.device("cpu")  # run inference on CPU
pretrained_model.to(device)
pretrained_model = pretrained_model.float()
pretrained_model.eval()

# Opening tag of a field, e.g. "<s_menu>"; group 1 is the field name.
_START_TAG_RE = re.compile(r"<s_(.*?)>", re.IGNORECASE)
# Separator between sibling leaf values and between sibling records.
_SEP_TOKEN = "<sep/>"


def token2json(tokens, is_inner_value=False):
    """Convert a generated Donut tag sequence into nested Python objects.

    ``<s_key>value</s_key>`` becomes ``{"key": "value"}``. Nested tags become
    nested dicts. Values separated by ``<sep/>`` become lists; a single value
    is collapsed to a plain string (leaf) or dict (non-leaf).

    Args:
        tokens: Decoded sequence with the task start tag, EOS and PAD removed.
        is_inner_value: True on recursive calls. Inner calls return a list of
            records instead of a single dict.

    Returns:
        At top level, a dict of parsed fields, or ``{"text_sequence": tokens}``
        when no field could be parsed. On inner calls, a list of records
        (empty when nothing was parsed).
    """
    output = {}
    while tokens:
        start_match = _START_TAG_RE.search(tokens)
        if start_match is None:
            break
        key = start_match.group(1)
        start_token = start_match.group()
        # Escape the key: field names are model output and may contain regex
        # metacharacters.
        end_match = re.search(rf"</s_{re.escape(key)}>", tokens, re.IGNORECASE)
        if end_match is None:
            # Unclosed field: drop the dangling start tag and keep scanning.
            tokens = tokens.replace(start_token, "")
            continue

        end_token = end_match.group()
        content = re.search(
            f"{re.escape(start_token)}(.*?){re.escape(end_token)}",
            tokens,
            re.IGNORECASE | re.DOTALL,  # DOTALL: field content may span lines
        )
        if content is not None:
            content = content.group(1).strip()
            if "<s_" in content and "</s_" in content:
                # Non-leaf node: parse children; a single child collapses to a dict.
                value = token2json(content, is_inner_value=True)
                if value:
                    if len(value) == 1:
                        value = value[0]
                    output[key] = value
            else:
                # Leaf node: <sep/>-separated values; a single value collapses
                # to a plain string.
                leaves = [leaf.strip() for leaf in content.split(_SEP_TOKEN)]
                output[key] = leaves[0] if len(leaves) == 1 else leaves

        tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
        if tokens.startswith(_SEP_TOKEN):
            # Sibling record at the same level: return all records as a list.
            return [output] + token2json(tokens[len(_SEP_TOKEN):], is_inner_value=True)

    if output:
        return [output] if is_inner_value else output
    return [] if is_inner_value else {"text_sequence": tokens}


def demo_process(files):
    """Run Donut on each uploaded image and return one parsed result per image.

    Args:
        files: Value of the multi-file ``gr.File`` input. This is a sequence of
            uploads that ``PIL.Image.open`` accepts, or ``None`` when the form
            is submitted without files.

    Returns:
        A list holding one ``token2json`` result per uploaded image.
    """
    if not files:
        # Submitting with no upload yields None; iterating it would raise.
        return []

    tokenizer = processor.tokenizer
    results = []
    for file in files:
        # Close the file handle once the pixels are decoded (convert() copies).
        with Image.open(file) as img:
            input_img = img.convert("RGB")
        input_img = input_img.resize(INPUT_SIZE)

        pixel_values = processor(input_img, return_tensors="pt", padding=True).pixel_values
        pixel_values = pixel_values.to(device).float()

        # Seed decoding with the model's configured start token.
        decoder_input_ids = torch.full(
            (1, 1), pretrained_model.config.decoder_start_token_id, device=device
        )
        outputs = pretrained_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=pretrained_model.config.decoder.max_length,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # One image and greedy search give a single generated sequence.
        sequence = tokenizer.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        # Remove the first tag (the task start token) before parsing.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        results.append(token2json(sequence))
    return results


# Page background, embedded as a base64-encoded SVG data URL.
# TODO(review): the SVG markup was lost in extraction; restore it here.
sprinkle_svg = """ """
encoded_svg = base64.b64encode(sprinkle_svg.encode("utf-8")).decode("ascii")
background_url = f"data:image/svg+xml;base64,{encoded_svg}"

# Gradio interface: multi-image upload in, list of parsed JSON results out.
demo = gr.Interface(
    fn=demo_process,
    inputs=gr.File(file_types=["image"], label="Upload multiple images", file_count="multiple"),
    outputs=gr.JSON(),  # JSON list output, one entry per image
    # Title for the demo.
    # NOTE(review): the original title markup was lost in extraction; this
    # reconstruction uses the .white-title class defined in css below — confirm.
    description="<div class='white-title'>Donut 🍩 Demonstration</div>",
    theme="soft",
    css=f"""
.gradio-container {{
    background-color: #ffffff;
    background-image: url('{background_url}');
    background-size: cover;
    background-repeat: no-repeat;
    background-position: center;
}}
.white-title {{
    background: #fff;
    color: #333;
    font-weight: bold;
    font-size: 2rem;
    padding: 1rem 2rem;
    border-radius: 12px;
    text-align: center;
    margin-bottom: 24px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.08);
    border: 1px solid #eee;
}}
""",
)

demo.launch(debug=True)