import gradio as gr
import torch
import os
from PIL import Image
import cairosvg
import io
import tempfile
import argparse
import gc
import yaml
import glob
import numpy as np
import time
import threading
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
from decoder import SketchDecoder
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from tokenizer import SVGTokenizer
# Load config
CONFIG_PATH = './config.yaml'
with open(CONFIG_PATH, 'r') as f:
config = yaml.safe_load(f)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
# Global Models (will be loaded based on selected model size)
tokenizer = None
processor = None
sketch_decoder = None
svg_tokenizer = None
current_model_size = None # Track which model is currently loaded
# Thread lock for model inference
generation_lock = threading.Lock()
model_loading_lock = threading.Lock()
# Constants from config
SYSTEM_PROMPT = """You are an expert SVG code generator.
Generate precise, valid SVG path commands that accurately represent the described scene or object.
Focus on capturing key shapes, spatial relationships, and visual composition."""
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
AVAILABLE_MODEL_SIZES = list(config.get('models', {}).keys())
DEFAULT_MODEL_SIZE = config.get('default_model_size', '8B')
# ============================================================
# Helper function to get config value (model-specific or shared)
# ============================================================
def get_config_value(model_size, *keys):
"""Get config value with model-specific override support."""
# Try model-specific config first
model_cfg = config.get('models', {}).get(model_size, {})
value = model_cfg
for key in keys:
if isinstance(value, dict) and key in value:
value = value[key]
else:
value = None
break
# Fallback to shared config if not found
if value is None:
value = config
for key in keys:
if isinstance(value, dict) and key in value:
value = value[key]
else:
return None
return value
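# Lookup sketch (illustrative keys/values): given a config.yaml containing
#   models: {'8B': {'huggingface': {'omnisvg_model': 'OmniSVG/OmniSVG-8B'}}}
#   huggingface: {'omnisvg_model': 'OmniSVG/OmniSVG'}
# get_config_value('8B', 'huggingface', 'omnisvg_model') returns the
# model-specific entry, while an unknown size falls back to the shared value.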
# ============================================================
# Image processing settings from config (shared)
# ============================================================
image_config = config.get('image', {})
TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
RENDER_SIZE = image_config.get('render_size', 512)
BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)
# ============================================================
# Color settings from config (shared)
# ============================================================
colors_config = config.get('colors', {})
BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
colors_config.get('color_token_start', 40010) + 2)
# ============================================================
# Model settings from config (shared)
# ============================================================
model_config = config.get('model', {})
BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
MAX_LENGTH = model_config.get('max_length', 1024)
# Max length limits for UI - reduced for faster generation
MAX_LENGTH_MIN = 256
MAX_LENGTH_MAX = 2048
MAX_LENGTH_DEFAULT = 512 # Reduced default for faster generation
# ============================================================
# Task configurations with defaults from config (shared)
# ============================================================
task_config = config.get('task_configs', {})
TASK_CONFIGS = {
"text-to-svg-icon": task_config.get('text_to_svg_icon', {
"default_temperature": 0.5,
"default_top_p": 0.88,
"default_top_k": 50,
"default_repetition_penalty": 1.05,
}),
"text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
"default_temperature": 0.6,
"default_top_p": 0.90,
"default_top_k": 60,
"default_repetition_penalty": 1.03,
}),
"image-to-svg": task_config.get('image_to_svg', {
"default_temperature": 0.3,
"default_top_p": 0.90,
"default_top_k": 50,
"default_repetition_penalty": 1.05,
})
}
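# Expected config.yaml layout for the per-task overrides above (illustrative
# sketch; only keys the dict above actually reads):
#   task_configs:
#     text_to_svg_icon:
#       default_temperature: 0.5
#       default_top_p: 0.88
#       default_top_k: 50
#       default_repetition_penalty: 1.05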
# ============================================================
# Generation parameters from config (shared)
# ============================================================
gen_config = config.get('generation', {})
DEFAULT_NUM_CANDIDATES = 1 # Changed to 1 to save GPU quota
MAX_NUM_CANDIDATES = 4 # Reduced max to save GPU quota
EXTRA_CANDIDATES_BUFFER = 2 # Reduced buffer
# ============================================================
# Validation settings from config (shared)
# ============================================================
validation_config = config.get('validation', {})
MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
# Custom CSS
CUSTOM_CSS = """
/* Main container centering */
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding: 20px !important;
}
/* Header styling */
.header-container {
text-align: center;
margin-bottom: 20px;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 16px;
color: white;
}
.header-container h1 {
margin: 0;
font-size: 2.5em;
font-weight: 700;
}
.header-container p {
margin: 10px 0 0 0;
opacity: 0.9;
font-size: 1.1em;
}
/* Model selector styling */
.model-selector {
background: #f0f4f8;
border: 2px solid #667eea;
border-radius: 12px;
padding: 15px;
margin-bottom: 20px;
}
.model-selector-title {
font-weight: 700;
color: #667eea;
margin-bottom: 10px;
}
/* Tips section */
.tips-box {
background: #f8f9fa;
border-radius: 12px;
padding: 20px;
margin-bottom: 20px;
border: 1px solid #e0e0e0;
}
.tips-box h3 {
margin-top: 0;
color: #333;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.tip-category {
background: white;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
border-left: 4px solid #667eea;
}
.tip-category h4 {
margin: 0 0 10px 0;
color: #667eea;
}
.tip-category code {
background: #f0f0f0;
padding: 2px 6px;
border-radius: 4px;
font-size: 0.9em;
}
.example-prompt {
background: #e8f4fd;
padding: 10px;
border-radius: 6px;
margin: 8px 0;
font-style: italic;
font-size: 0.95em;
color: #333;
}
.red-tip {
color: #dc3545;
font-weight: 600;
}
.red-box {
background: #fff5f5;
border: 1px solid #ffcccc;
border-left: 4px solid #dc3545;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.red-box strong {
color: #dc3545;
}
.orange-box {
background: #fff8e6;
border: 1px solid #ffc107;
border-left: 4px solid #ff9800;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.orange-box strong {
color: #ff9800;
}
.green-box {
background: #e8f5e9;
border: 1px solid #81c784;
border-left: 4px solid #4caf50;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.green-box strong {
color: #4caf50;
}
.blue-box {
background: #e3f2fd;
border: 1px solid #90caf9;
border-left: 4px solid #2196f3;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.blue-box strong {
color: #2196f3;
}
/* Tab styling */
.tabs {
border-radius: 12px !important;
overflow: hidden;
}
.tabitem {
padding: 20px !important;
}
/* Button styling */
.primary-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
font-weight: 600 !important;
padding: 12px 24px !important;
font-size: 1.1em !important;
}
.primary-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
/* Settings group */
.settings-group {
background: #f8f9fa;
border-radius: 10px;
padding: 15px;
margin: 10px 0;
}
.advanced-settings {
background: #f0f4f8;
border-radius: 8px;
padding: 12px;
margin-top: 10px;
}
/* Code output */
.code-output textarea {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
font-size: 12px !important;
background: #1e1e1e !important;
color: #d4d4d4 !important;
border-radius: 8px !important;
}
/* Input image area */
.input-image {
border: 2px dashed #ccc;
border-radius: 12px;
transition: border-color 0.3s;
}
.input-image:hover {
border-color: #667eea;
}
/* Footer */
.footer {
text-align: center;
padding: 20px;
color: #666;
font-size: 0.9em;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.gradio-container {
padding: 10px !important;
}
.header-container h1 {
font-size: 1.8em;
}
}
"""
# Simplified Tips HTML for bottom of page
TIPS_HTML_BOTTOM = """
💡 Tips & Guide
🎲 Not getting the result you want?
This is normal! Just click "Generate SVG" again to re-roll. Each generation is different - try 2-3 times to find the best result!
📝 Prompting Tips
- Use geometric descriptions: "triangular roof", "circular head", "oval body", "curved tail"
- Specify colors for EACH element: "red roof", "blue shirt", "black outline", "green grass"
- Keep it simple: Use short, clear phrases connected by commas
- Add positions: "at top", "in center", "at bottom", "facing right"
⚙️ Parameter Guide
- Max Length: Lower (256-1024) = faster & simpler | Higher (1024-2048) = slower & more detailed
- Temperature: Lower (0.2-0.4) = more accurate | Higher (0.5-0.7) = more creative
- Messy result? Lower temperature and top_k
- Too simple? Increase max_length and temperature
✨ Recommended Prompt Structure
[Subject] + [Shape descriptions with colors] + [Position] + [Style]
Example: "A fox logo: triangular orange head, pointed ears, white chest marking, facing right. Minimalist flat style."
"""
# Image-to-SVG specific tips (simplified)
IMAGE_TIPS_HTML = """
🎲 Tips for Best Results
- Simple images work best: Clean backgrounds, clear shapes
- Not satisfied? Just click generate again to re-roll!
- PNG with transparency is automatically converted to white background
"""
def parse_args():
parser = argparse.ArgumentParser(description='SVG Generator Service')
parser.add_argument('--listen', type=str, default='0.0.0.0')
parser.add_argument('--port', type=int, default=7860)
parser.add_argument('--share', action='store_true')
parser.add_argument('--debug', action='store_true')
return parser.parse_args()
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
"""
Download model weights from Hugging Face Hub.
"""
print(f"Downloading {filename} from {repo_id}...")
try:
local_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
resume_download=True,
)
print(f"Successfully downloaded to: {local_path}")
return local_path
except Exception as e:
print(f"Error downloading from {repo_id}: {e}")
raise
def is_local_path(path: str) -> bool:
"""Check if a path is a local filesystem path or a HuggingFace repo ID."""
if os.path.exists(path):
return True
    # Relative or absolute POSIX-style prefix
    if path.startswith(('/', './', '../')):
        return True
    # Path-like string whose parent directory exists
    if os.path.sep in path and os.path.exists(os.path.dirname(path)):
        return True
    # Windows drive letter, e.g. "C:\\models\\omnisvg"
    if len(path) > 1 and path[1] == ':':
        return True
    return False
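# Examples of how paths are classified (assuming none of them exist locally):
#   is_local_path('./weights')           -> True   (relative prefix)
#   is_local_path('C:\\models\\omnisvg') -> True   (Windows drive letter)
#   is_local_path('OmniSVG/OmniSVG')     -> False  (treated as a HF repo ID)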
def load_models(model_size: str, weight_path: str = None, model_path: str = None):
"""
Load all models for a specific model size.
"""
global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
# Use config values if not provided
if weight_path is None:
weight_path = get_config_value(model_size, 'huggingface', 'omnisvg_model')
if model_path is None:
model_path = get_config_value(model_size, 'huggingface', 'qwen_model')
print(f"\n{'='*60}")
print(f"Loading {model_size} Model")
print(f"{'='*60}")
print(f"Qwen model: {model_path}")
print(f"OmniSVG weights: {weight_path}")
print(f"Precision: {DTYPE}")
# Load Qwen tokenizer and processor
print("\n[1/3] Loading tokenizer and processor...")
tokenizer = AutoTokenizer.from_pretrained(
model_path,
padding_side="left",
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(
model_path,
padding_side="left",
trust_remote_code=True
)
processor.tokenizer.padding_side = "left"
print("Tokenizer and processor loaded successfully!")
# Initialize sketch decoder with model_size
print("\n[2/3] Initializing SketchDecoder...")
sketch_decoder = SketchDecoder(
config_path=CONFIG_PATH,
model_path=model_path,
model_size=model_size,
pix_len=MAX_LENGTH_MAX, # Use max possible length for model initialization
text_len=config.get('text', {}).get('max_length', 200),
torch_dtype=DTYPE
)
# Load OmniSVG weights
print("\n[3/3] Loading OmniSVG weights...")
if is_local_path(weight_path):
bin_path = os.path.join(weight_path, "pytorch_model.bin")
if not os.path.exists(bin_path):
if os.path.exists(weight_path) and weight_path.endswith('.bin'):
bin_path = weight_path
else:
raise FileNotFoundError(
f"Could not find pytorch_model.bin at {weight_path}. "
f"Please provide a valid local path or HuggingFace repo ID."
)
print(f"Loading weights from local path: {bin_path}")
else:
print(f"Downloading weights from HuggingFace: {weight_path}")
bin_path = download_model_weights(weight_path, "pytorch_model.bin")
state_dict = torch.load(bin_path, map_location='cpu')
sketch_decoder.load_state_dict(state_dict)
print("OmniSVG weights loaded successfully!")
sketch_decoder = sketch_decoder.to(device).eval()
# Initialize SVG tokenizer with model_size
svg_tokenizer = SVGTokenizer(CONFIG_PATH, model_size=model_size)
current_model_size = model_size
print("\n" + "="*60)
print(f"All {model_size} models loaded successfully!")
print("="*60 + "\n")
def ensure_model_loaded(model_size: str):
"""
Ensure the specified model is loaded. Load or switch if necessary.
This function should be called within @spaces.GPU decorated functions.
"""
global current_model_size, sketch_decoder, tokenizer, processor, svg_tokenizer
if current_model_size == model_size and sketch_decoder is not None:
return # Already loaded
with model_loading_lock:
# Double-check after acquiring lock
if current_model_size == model_size and sketch_decoder is not None:
return
# Clear old models if switching
if current_model_size is not None:
print(f"Switching from {current_model_size} to {model_size}...")
del sketch_decoder
del tokenizer
del processor
del svg_tokenizer
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Load new model
load_models(model_size)
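# Typical call site (sketch): invoked at the top of each GPU-decorated
# generation function so the requested checkpoint is resident before inference:
#   @spaces.GPU
#   def generate_svg(model_size, ...):
#       ensure_model_loaded(model_size)
#       with generation_lock:
#           ...  # run inference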
def detect_text_subtype(text_prompt):
"""Auto-detect text prompt subtype"""
text_lower = text_prompt.lower()
icon_keywords = ['icon', 'logo', 'symbol', 'badge', 'button', 'emoji', 'glyph', 'simple',
'arrow', 'triangle', 'circle', 'square', 'heart', 'star', 'checkmark']
if any(kw in text_lower for kw in icon_keywords):
return "icon"
illustration_keywords = [
'illustration', 'scene', 'person', 'people', 'character', 'man', 'woman', 'boy', 'girl',
'avatar', 'portrait', 'face', 'head', 'body',
'cat', 'dog', 'bird', 'animal', 'pet', 'fox', 'rabbit',
'sitting', 'standing', 'walking', 'running', 'sleeping', 'holding', 'playing',
'house', 'building', 'tree', 'garden', 'landscape', 'mountain', 'forest', 'city',
'ocean', 'beach', 'sunset', 'sunrise', 'sky'
]
match_count = sum(1 for kw in illustration_keywords if kw in text_lower)
if match_count >= 1 or len(text_prompt) > 50:
return "illustration"
return "icon"
def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
"""
Detect if image has non-white background and optionally replace it.
"""
if threshold is None:
threshold = BACKGROUND_THRESHOLD
if edge_sample_ratio is None:
edge_sample_ratio = EDGE_SAMPLE_RATIO
img_array = np.array(image)
if image.mode == 'RGBA':
bg = Image.new('RGBA', image.size, (255, 255, 255, 255))
composite = Image.alpha_composite(bg, image)
return composite.convert('RGB'), True
h, w = img_array.shape[:2]
edge_pixels = []
sample_count = max(MIN_EDGE_SAMPLES, int(min(h, w) * edge_sample_ratio))
for i in range(0, w, max(1, w // sample_count)):
edge_pixels.append(img_array[0, i])
edge_pixels.append(img_array[h-1, i])
for i in range(0, h, max(1, h // sample_count)):
edge_pixels.append(img_array[i, 0])
edge_pixels.append(img_array[i, w-1])
edge_pixels = np.array(edge_pixels)
if len(edge_pixels) > 0:
mean_edge = edge_pixels.mean(axis=0)
if np.all(mean_edge > threshold):
return image, False
    if img_array.ndim == 3 and img_array.shape[2] >= 3:
edge_colors = []
for i in range(w):
edge_colors.append(tuple(img_array[0, i, :3]))
edge_colors.append(tuple(img_array[h-1, i, :3]))
for i in range(h):
edge_colors.append(tuple(img_array[i, 0, :3]))
edge_colors.append(tuple(img_array[i, w-1, :3]))
from collections import Counter
color_counts = Counter(edge_colors)
bg_color = color_counts.most_common(1)[0][0]
color_diff = np.sqrt(np.sum((img_array[:, :, :3].astype(float) - np.array(bg_color)) ** 2, axis=2))
bg_mask = color_diff < COLOR_SIMILARITY_THRESHOLD
result = img_array.copy()
if result.shape[2] == 4:
result[bg_mask] = [255, 255, 255, 255]
else:
result[bg_mask] = [255, 255, 255]
return Image.fromarray(result).convert('RGB'), True
return image, False
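# Usage sketch (hypothetical file name): flatten the dominant edge color to
# white before the image reaches the model.
#   img, changed = detect_and_replace_background(Image.open('sketch.png').convert('RGB'))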
def preprocess_image_for_svg(image, replace_background=True, target_size=None):
"""
Preprocess image for SVG generation.
"""
if target_size is None:
target_size = TARGET_IMAGE_SIZE
if isinstance(image, str):
raw_img = Image.open(image)
else:
raw_img = image
was_modified = False
if raw_img.mode == 'RGBA':
bg = Image.new('RGBA', raw_img.size, (255, 255, 255, 255))
img_with_bg = Image.alpha_composite(bg, raw_img).convert('RGB')
was_modified = True
elif raw_img.mode == 'LA' or raw_img.mode == 'PA':
raw_img = raw_img.convert('RGBA')
bg = Image.new('RGBA', raw_img.size, (255, 255, 255, 255))
img_with_bg = Image.alpha_composite(bg, raw_img).convert('RGB')
was_modified = True
elif raw_img.mode != 'RGB':
img_with_bg = raw_img.convert('RGB')
else:
img_with_bg = raw_img
if replace_background:
img_with_bg, bg_replaced = detect_and_replace_background(img_with_bg)
was_modified = was_modified or bg_replaced
img_resized = img_with_bg.resize((target_size, target_size), Image.Resampling.LANCZOS)
return img_resized, was_modified
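# Usage sketch (hypothetical path): yields a TARGET_IMAGE_SIZE x TARGET_IMAGE_SIZE
# RGB image plus a flag indicating whether the pixels were altered.
#   img, was_modified = preprocess_image_for_svg('input.png')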
def prepare_inputs(task_type, content):
"""Prepare model inputs"""
if task_type == "text-to-svg":
prompt_text = str(content).strip()
instruction = f"""Generate an SVG illustration for: {prompt_text}
Requirements:
- Create complete SVG path commands
- Include proper coordinates and colors
- Maintain visual clarity and composition"""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": [{"type": "text", "text": instruction}]}
]
text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text_input], padding=True, truncation=True, return_tensors="pt")
else: # image-to-svg
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": [
{"type": "text", "text": "Generate SVG code that accurately represents this image:"},
{"type": "image", "image": content},
]}
]
text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(messages)
inputs = processor(text=[text_input], images=image_inputs, padding=True, truncation=True, return_tensors="pt")
return inputs
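# Downstream sketch: the processor output is a BatchFeature that is moved to
# the model device before decoding (generation itself is omitted here):
#   inputs = prepare_inputs("text-to-svg", "a red heart icon").to(device)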
def render_svg_to_image(svg_str, size=None):
"""Render SVG to high-quality PIL Image"""
if size is None:
size = RENDER_SIZE
try:
png_data = cairosvg.svg2png(
bytestring=svg_str.encode('utf-8'),
output_width=size,
output_height=size
)
image_rgba = Image.open(io.BytesIO(png_data)).convert("RGBA")
bg = Image.new("RGB", image_rgba.size, (255, 255, 255))
bg.paste(image_rgba, mask=image_rgba.split()[3])
return bg
except Exception as e:
print(f"Render error: {e}")
return None
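# Usage sketch: rasterize generated markup for preview; malformed SVG returns
# None instead of raising.
#   preview = render_svg_to_image('<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 10 10"><rect width="10" height="10" fill="red"/></svg>')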
def create_gallery_html(candidates, cols=4):
"""Create HTML gallery for multiple SVG candidates"""
if not candidates:
        return '<p>No candidates generated</p>'
items_html = []
for i, cand in enumerate(candidates):
svg_str = cand['svg']
        if 'viewBox' not in svg_str:
            # Inject a viewBox so the SVG scales inside its gallery cell
            # (0 0 200 200 is an assumed coordinate range for generated paths)
            svg_str = svg_str.replace('<svg', '<svg viewBox="0 0 200 200"', 1)