Spaces:
Running
on
Zero
Running
on
Zero
Update tokenizer.py
Browse files- tokenizer.py +78 -44
tokenizer.py
CHANGED
|
@@ -8,38 +8,82 @@ from deepsvg.svglib.geom import Bbox
|
|
| 8 |
|
| 9 |
|
| 10 |
class SVGTokenizer:
|
| 11 |
-
"""SVG tokenizer -
|
| 12 |
|
| 13 |
-
def __init__(self, config_path: str = "./config.yaml"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
with open(config_path, 'r') as f:
|
| 15 |
self.config = yaml.safe_load(f)
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
self._load_config()
|
| 18 |
self.pixel2xy = self._create_pixel2xy_mapping()
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def _load_config(self):
|
| 21 |
-
"""
|
| 22 |
-
# ========== Token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
tokens_cfg = self.config['tokens']
|
| 24 |
self.NUM_SVG_END = tokens_cfg['svg_end']
|
| 25 |
-
self.BASE_OFFSET = tokens_cfg['base_offset']
|
| 26 |
-
self.NUM_MASK_AND_EOM = tokens_cfg['num_mask_and_eom']
|
| 27 |
self.NUM_END_TOKEN = tokens_cfg['num_end_token']
|
| 28 |
|
| 29 |
-
# ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
coords_cfg = self.config['coordinates']
|
| 31 |
self.BBOX = coords_cfg['bbox']
|
| 32 |
-
self.PIX_PAD = coords_cfg['pix_pad_offset']
|
| 33 |
-
self.COORD_PAD = coords_cfg['coord_pad_offset']
|
| 34 |
|
| 35 |
-
# ==========
|
| 36 |
colors_cfg = self.config['colors']
|
| 37 |
self.COLOR_TOKEN_START_RAW = colors_cfg['color_token_start']
|
| 38 |
-
self.COLOR_START_OFFSET = colors_cfg['color_start_offset']
|
| 39 |
-
self.COLOR_END_OFFSET = colors_cfg['color_end_offset']
|
| 40 |
self.MAX_COLOR_TOKENS = colors_cfg['max_color_tokens']
|
| 41 |
|
| 42 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
commands_cfg = self.config['svg_commands']
|
| 44 |
self.CMD_MOVE = commands_cfg['move']
|
| 45 |
self.CMD_LINE = commands_cfg['line']
|
|
@@ -47,41 +91,37 @@ class SVGTokenizer:
|
|
| 47 |
self.CMD_ARC = commands_cfg['arc']
|
| 48 |
self.CMD_CLOSE = commands_cfg['close']
|
| 49 |
|
| 50 |
-
# ==========
|
| 51 |
model_cfg = self.config['model']
|
| 52 |
self.BOS_TOKEN_ID = model_cfg['bos_token_id']
|
| 53 |
self.EOS_TOKEN_ID = model_cfg['eos_token_id']
|
| 54 |
self.PAD_TOKEN_ID = model_cfg['pad_token_id']
|
| 55 |
|
| 56 |
-
# ========== Arc
|
| 57 |
arc_cfg = self.config.get('arc', {})
|
| 58 |
self.ARC_PARAM_OFFSET = arc_cfg.get('param_offset', 44500)
|
| 59 |
self.ARC_PARAM_RANGE = arc_cfg.get('param_range', 100)
|
| 60 |
self.ARC_PARAM_START = self.ARC_PARAM_OFFSET + self.BASE_OFFSET
|
| 61 |
|
| 62 |
-
# ==========
|
| 63 |
-
# PIXEL_OFFSET: 从配置推导
|
| 64 |
-
# 命令token存储值 - BASE_OFFSET - PIXEL_OFFSET = CMD_MOVE
|
| 65 |
-
# (NUM_MASK_AND_EOM + NUM_SVG_END) - BASE_OFFSET - PIXEL_OFFSET = CMD_MOVE
|
| 66 |
self.PIXEL_OFFSET = (self.NUM_MASK_AND_EOM - self.BASE_OFFSET +
|
| 67 |
self.NUM_SVG_END - self.CMD_MOVE)
|
| 68 |
|
| 69 |
-
#
|
| 70 |
self.CMD_TOKEN_START = self.NUM_MASK_AND_EOM + self.NUM_SVG_END
|
| 71 |
self.CMD_TOKEN_END = self.PIX_PAD + self.NUM_SVG_END
|
| 72 |
|
| 73 |
-
#
|
| 74 |
self.COORD_TOKEN_START = self.PIX_PAD + self.NUM_SVG_END
|
| 75 |
|
| 76 |
-
#
|
| 77 |
self.COLOR_COORD_BOUNDARY = self.COLOR_TOKEN_START_RAW + 1 + self.BASE_OFFSET
|
| 78 |
|
| 79 |
-
#
|
| 80 |
-
# 减去 PIXEL_OFFSET 后的颜色token下限
|
| 81 |
self.COLOR_THRESHOLD = self.COLOR_TOKEN_START_RAW - self.PIXEL_OFFSET + 1
|
| 82 |
|
| 83 |
def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
|
| 84 |
-
"""
|
| 85 |
pixel2xy = {}
|
| 86 |
x = np.linspace(0, self.BBOX - 1, self.BBOX)
|
| 87 |
y = np.linspace(0, self.BBOX - 1, self.BBOX)
|
|
@@ -89,13 +129,12 @@ class SVGTokenizer:
|
|
| 89 |
xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
|
| 90 |
|
| 91 |
for pixel, xy in enumerate(xy_grid):
|
| 92 |
-
# xy + COORD_PAD + NUM_SVG_END
|
| 93 |
pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
|
| 94 |
|
| 95 |
return pixel2xy
|
| 96 |
|
| 97 |
def token_to_color(self, color_token: int) -> str:
|
| 98 |
-
"""
|
| 99 |
try:
|
| 100 |
if color_token == self.COLOR_TOKEN_START_RAW:
|
| 101 |
return "none"
|
|
@@ -123,37 +162,35 @@ class SVGTokenizer:
|
|
| 123 |
return "#808080"
|
| 124 |
|
| 125 |
def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
|
| 126 |
-
"""
|
| 127 |
-
|
| 128 |
-
"""
|
| 129 |
-
# 移除 bos/eos
|
| 130 |
generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()
|
| 131 |
|
| 132 |
sample_xys = []
|
| 133 |
|
| 134 |
for pixel in generated_pixels:
|
| 135 |
try:
|
| 136 |
-
# 1.
|
| 137 |
if self.CMD_TOKEN_START <= pixel < self.CMD_TOKEN_END:
|
| 138 |
xy = np.array([pixel - self.BASE_OFFSET,
|
| 139 |
pixel - self.BASE_OFFSET]).astype(int)
|
| 140 |
sample_xys.append(xy)
|
| 141 |
|
| 142 |
-
# 2.
|
| 143 |
elif self.COORD_TOKEN_START <= pixel < self.COLOR_COORD_BOUNDARY:
|
| 144 |
pixel_index = pixel - self.COORD_TOKEN_START
|
| 145 |
if pixel_index in self.pixel2xy:
|
| 146 |
xy = self.pixel2xy[pixel_index] - self.BASE_OFFSET
|
| 147 |
sample_xys.append(xy)
|
| 148 |
|
| 149 |
-
# 3. Arc
|
| 150 |
elif (self.ARC_PARAM_START + 1 <= pixel <
|
| 151 |
self.ARC_PARAM_START + 1 + self.ARC_PARAM_RANGE):
|
| 152 |
value = pixel - self.ARC_PARAM_START - 1
|
| 153 |
xy = np.array([value, value]).astype(int)
|
| 154 |
sample_xys.append(xy)
|
| 155 |
|
| 156 |
-
# 4.
|
| 157 |
elif self.COLOR_COORD_BOUNDARY <= pixel < self.ARC_PARAM_START:
|
| 158 |
xy = np.array([pixel - self.BASE_OFFSET,
|
| 159 |
pixel - self.BASE_OFFSET]).astype(int)
|
|
@@ -169,15 +206,12 @@ class SVGTokenizer:
|
|
| 169 |
return np.array([]).reshape(0, 2)
|
| 170 |
|
| 171 |
def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
|
| 172 |
-
"""
|
| 173 |
-
按照 dataset.py 的 raster_svg 逻辑
|
| 174 |
-
关键:pixels -= PIXEL_OFFSET 是核心转换步骤
|
| 175 |
-
"""
|
| 176 |
try:
|
| 177 |
if len(pixels) == 0:
|
| 178 |
return [[]], []
|
| 179 |
|
| 180 |
-
#
|
| 181 |
pixels = pixels - self.PIXEL_OFFSET
|
| 182 |
|
| 183 |
svg_tensors = []
|
|
@@ -250,11 +284,11 @@ class SVGTokenizer:
|
|
| 250 |
path_tensor.append(cmd_tensor.tolist())
|
| 251 |
i += 2
|
| 252 |
|
| 253 |
-
#
|
| 254 |
elif pix[0] >= self.COLOR_THRESHOLD:
|
| 255 |
if path_tensor:
|
| 256 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 257 |
-
#
|
| 258 |
color_token = int(pix[0] + self.PIXEL_OFFSET - 1)
|
| 259 |
color_tensors.append(color_token)
|
| 260 |
path_tensor = []
|
|
@@ -266,7 +300,7 @@ class SVGTokenizer:
|
|
| 266 |
print(f"Error at position {i}: {e}")
|
| 267 |
break
|
| 268 |
|
| 269 |
-
#
|
| 270 |
if path_tensor:
|
| 271 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 272 |
|
|
@@ -280,7 +314,7 @@ class SVGTokenizer:
|
|
| 280 |
|
| 281 |
def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
|
| 282 |
colors: Optional[List[int]]) -> SVG:
|
| 283 |
-
"""
|
| 284 |
paths = []
|
| 285 |
|
| 286 |
if not svg_tensors:
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class SVGTokenizer:
|
| 11 |
+
"""SVG tokenizer - supports both 8B and 4B models via config.yaml"""
|
| 12 |
|
| 13 |
+
def __init__(self, config_path: str = "./config.yaml", model_size: str = None):
|
| 14 |
+
"""
|
| 15 |
+
Initialize SVGTokenizer.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
config_path: Path to config.yaml
|
| 19 |
+
model_size: Model size ("8B" or "4B"). If None, uses default from config.
|
| 20 |
+
"""
|
| 21 |
with open(config_path, 'r') as f:
|
| 22 |
self.config = yaml.safe_load(f)
|
| 23 |
|
| 24 |
+
# Determine model size
|
| 25 |
+
self.model_size = model_size or self.config.get('default_model_size', '8B')
|
| 26 |
+
if self.model_size not in self.config.get('models', {}):
|
| 27 |
+
raise ValueError(f"Invalid model_size: {self.model_size}. Must be one of: {list(self.config.get('models', {}).keys())}")
|
| 28 |
+
|
| 29 |
self._load_config()
|
| 30 |
self.pixel2xy = self._create_pixel2xy_mapping()
|
| 31 |
|
| 32 |
+
def _get_model_specific_config(self, *keys):
|
| 33 |
+
"""Get model-specific config value, with fallback to shared config."""
|
| 34 |
+
model_cfg = self.config.get('models', {}).get(self.model_size, {})
|
| 35 |
+
|
| 36 |
+
# Navigate through nested keys in model-specific config
|
| 37 |
+
value = model_cfg
|
| 38 |
+
for key in keys:
|
| 39 |
+
if isinstance(value, dict) and key in value:
|
| 40 |
+
value = value[key]
|
| 41 |
+
else:
|
| 42 |
+
value = None
|
| 43 |
+
break
|
| 44 |
+
|
| 45 |
+
# If not found in model-specific, try shared config
|
| 46 |
+
if value is None:
|
| 47 |
+
value = self.config
|
| 48 |
+
for key in keys:
|
| 49 |
+
if isinstance(value, dict) and key in value:
|
| 50 |
+
value = value[key]
|
| 51 |
+
else:
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
return value
|
| 55 |
+
|
| 56 |
def _load_config(self):
|
| 57 |
+
"""Load all constants from configuration file with model-specific overrides."""
|
| 58 |
+
# ========== Token-related configs ==========
|
| 59 |
+
# Model-specific tokens
|
| 60 |
+
self.NUM_MASK_AND_EOM = self._get_model_specific_config('tokens', 'num_mask_and_eom')
|
| 61 |
+
self.BASE_OFFSET = self._get_model_specific_config('tokens', 'base_offset')
|
| 62 |
+
|
| 63 |
+
# Shared tokens
|
| 64 |
tokens_cfg = self.config['tokens']
|
| 65 |
self.NUM_SVG_END = tokens_cfg['svg_end']
|
|
|
|
|
|
|
| 66 |
self.NUM_END_TOKEN = tokens_cfg['num_end_token']
|
| 67 |
|
| 68 |
+
# ========== Coordinate-related configs ==========
|
| 69 |
+
# Model-specific coordinates
|
| 70 |
+
self.PIX_PAD = self._get_model_specific_config('coordinates', 'pix_pad_offset')
|
| 71 |
+
self.COORD_PAD = self._get_model_specific_config('coordinates', 'coord_pad_offset')
|
| 72 |
+
|
| 73 |
+
# Shared coordinates
|
| 74 |
coords_cfg = self.config['coordinates']
|
| 75 |
self.BBOX = coords_cfg['bbox']
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
# ========== Color-related configs ==========
|
| 78 |
colors_cfg = self.config['colors']
|
| 79 |
self.COLOR_TOKEN_START_RAW = colors_cfg['color_token_start']
|
|
|
|
|
|
|
| 80 |
self.MAX_COLOR_TOKENS = colors_cfg['max_color_tokens']
|
| 81 |
|
| 82 |
+
# Model-specific colors
|
| 83 |
+
self.COLOR_START_OFFSET = self._get_model_specific_config('colors', 'color_start_offset')
|
| 84 |
+
self.COLOR_END_OFFSET = self._get_model_specific_config('colors', 'color_end_offset')
|
| 85 |
+
|
| 86 |
+
# ========== SVG command values ==========
|
| 87 |
commands_cfg = self.config['svg_commands']
|
| 88 |
self.CMD_MOVE = commands_cfg['move']
|
| 89 |
self.CMD_LINE = commands_cfg['line']
|
|
|
|
| 91 |
self.CMD_ARC = commands_cfg['arc']
|
| 92 |
self.CMD_CLOSE = commands_cfg['close']
|
| 93 |
|
| 94 |
+
# ========== Model-related configs ==========
|
| 95 |
model_cfg = self.config['model']
|
| 96 |
self.BOS_TOKEN_ID = model_cfg['bos_token_id']
|
| 97 |
self.EOS_TOKEN_ID = model_cfg['eos_token_id']
|
| 98 |
self.PAD_TOKEN_ID = model_cfg['pad_token_id']
|
| 99 |
|
| 100 |
+
# ========== Arc parameter configs ==========
|
| 101 |
arc_cfg = self.config.get('arc', {})
|
| 102 |
self.ARC_PARAM_OFFSET = arc_cfg.get('param_offset', 44500)
|
| 103 |
self.ARC_PARAM_RANGE = arc_cfg.get('param_range', 100)
|
| 104 |
self.ARC_PARAM_START = self.ARC_PARAM_OFFSET + self.BASE_OFFSET
|
| 105 |
|
| 106 |
+
# ========== Derived constants ==========
|
|
|
|
|
|
|
|
|
|
| 107 |
self.PIXEL_OFFSET = (self.NUM_MASK_AND_EOM - self.BASE_OFFSET +
|
| 108 |
self.NUM_SVG_END - self.CMD_MOVE)
|
| 109 |
|
| 110 |
+
# Command token range
|
| 111 |
self.CMD_TOKEN_START = self.NUM_MASK_AND_EOM + self.NUM_SVG_END
|
| 112 |
self.CMD_TOKEN_END = self.PIX_PAD + self.NUM_SVG_END
|
| 113 |
|
| 114 |
+
# Coordinate token start
|
| 115 |
self.COORD_TOKEN_START = self.PIX_PAD + self.NUM_SVG_END
|
| 116 |
|
| 117 |
+
# Color-coordinate boundary
|
| 118 |
self.COLOR_COORD_BOUNDARY = self.COLOR_TOKEN_START_RAW + 1 + self.BASE_OFFSET
|
| 119 |
|
| 120 |
+
# Color threshold for raster_svg
|
|
|
|
| 121 |
self.COLOR_THRESHOLD = self.COLOR_TOKEN_START_RAW - self.PIXEL_OFFSET + 1
|
| 122 |
|
| 123 |
def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
|
| 124 |
+
"""Create pixel to xy mapping following dataset.py logic."""
|
| 125 |
pixel2xy = {}
|
| 126 |
x = np.linspace(0, self.BBOX - 1, self.BBOX)
|
| 127 |
y = np.linspace(0, self.BBOX - 1, self.BBOX)
|
|
|
|
| 129 |
xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
|
| 130 |
|
| 131 |
for pixel, xy in enumerate(xy_grid):
|
|
|
|
| 132 |
pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
|
| 133 |
|
| 134 |
return pixel2xy
|
| 135 |
|
| 136 |
def token_to_color(self, color_token: int) -> str:
|
| 137 |
+
"""Convert token to color following dataset.py logic."""
|
| 138 |
try:
|
| 139 |
if color_token == self.COLOR_TOKEN_START_RAW:
|
| 140 |
return "none"
|
|
|
|
| 162 |
return "#808080"
|
| 163 |
|
| 164 |
def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
|
| 165 |
+
"""Process generated tokens following dataset.py logic."""
|
| 166 |
+
# Remove bos/eos
|
|
|
|
|
|
|
| 167 |
generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()
|
| 168 |
|
| 169 |
sample_xys = []
|
| 170 |
|
| 171 |
for pixel in generated_pixels:
|
| 172 |
try:
|
| 173 |
+
# 1. Command tokens: CMD_TOKEN_START <= pixel < CMD_TOKEN_END
|
| 174 |
if self.CMD_TOKEN_START <= pixel < self.CMD_TOKEN_END:
|
| 175 |
xy = np.array([pixel - self.BASE_OFFSET,
|
| 176 |
pixel - self.BASE_OFFSET]).astype(int)
|
| 177 |
sample_xys.append(xy)
|
| 178 |
|
| 179 |
+
# 2. Coordinate tokens: COORD_TOKEN_START <= pixel < COLOR_COORD_BOUNDARY
|
| 180 |
elif self.COORD_TOKEN_START <= pixel < self.COLOR_COORD_BOUNDARY:
|
| 181 |
pixel_index = pixel - self.COORD_TOKEN_START
|
| 182 |
if pixel_index in self.pixel2xy:
|
| 183 |
xy = self.pixel2xy[pixel_index] - self.BASE_OFFSET
|
| 184 |
sample_xys.append(xy)
|
| 185 |
|
| 186 |
+
# 3. Arc parameters: ARC_PARAM_START + 1 <= pixel < ARC_PARAM_START + 1 + ARC_PARAM_RANGE
|
| 187 |
elif (self.ARC_PARAM_START + 1 <= pixel <
|
| 188 |
self.ARC_PARAM_START + 1 + self.ARC_PARAM_RANGE):
|
| 189 |
value = pixel - self.ARC_PARAM_START - 1
|
| 190 |
xy = np.array([value, value]).astype(int)
|
| 191 |
sample_xys.append(xy)
|
| 192 |
|
| 193 |
+
# 4. Color tokens: COLOR_COORD_BOUNDARY <= pixel < ARC_PARAM_START
|
| 194 |
elif self.COLOR_COORD_BOUNDARY <= pixel < self.ARC_PARAM_START:
|
| 195 |
xy = np.array([pixel - self.BASE_OFFSET,
|
| 196 |
pixel - self.BASE_OFFSET]).astype(int)
|
|
|
|
| 206 |
return np.array([]).reshape(0, 2)
|
| 207 |
|
| 208 |
def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
|
| 209 |
+
"""Convert pixels to SVG tensors following dataset.py logic."""
|
|
|
|
|
|
|
|
|
|
| 210 |
try:
|
| 211 |
if len(pixels) == 0:
|
| 212 |
return [[]], []
|
| 213 |
|
| 214 |
+
# Key step: subtract PIXEL_OFFSET
|
| 215 |
pixels = pixels - self.PIXEL_OFFSET
|
| 216 |
|
| 217 |
svg_tensors = []
|
|
|
|
| 284 |
path_tensor.append(cmd_tensor.tolist())
|
| 285 |
i += 2
|
| 286 |
|
| 287 |
+
# Color token: pix[0] >= COLOR_THRESHOLD
|
| 288 |
elif pix[0] >= self.COLOR_THRESHOLD:
|
| 289 |
if path_tensor:
|
| 290 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 291 |
+
# Reverse transform: restore original color token
|
| 292 |
color_token = int(pix[0] + self.PIXEL_OFFSET - 1)
|
| 293 |
color_tensors.append(color_token)
|
| 294 |
path_tensor = []
|
|
|
|
| 300 |
print(f"Error at position {i}: {e}")
|
| 301 |
break
|
| 302 |
|
| 303 |
+
# Handle remaining path (without color)
|
| 304 |
if path_tensor:
|
| 305 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 306 |
|
|
|
|
| 314 |
|
| 315 |
def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
|
| 316 |
colors: Optional[List[int]]) -> SVG:
|
| 317 |
+
"""Apply colors and create final SVG."""
|
| 318 |
paths = []
|
| 319 |
|
| 320 |
if not svg_tensors:
|