OmniSVG commited on
Commit
aac856b
·
verified ·
1 Parent(s): 90e2bbe

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +78 -44
tokenizer.py CHANGED
@@ -8,38 +8,82 @@ from deepsvg.svglib.geom import Bbox
8
 
9
 
10
  class SVGTokenizer:
11
- """SVG tokenizer - config.yaml加载所有配置,避免硬编码"""
12
 
13
- def __init__(self, config_path: str = "./config.yaml"):
 
 
 
 
 
 
 
14
  with open(config_path, 'r') as f:
15
  self.config = yaml.safe_load(f)
16
 
 
 
 
 
 
17
  self._load_config()
18
  self.pixel2xy = self._create_pixel2xy_mapping()
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def _load_config(self):
21
- """从配置文件加载所有常量"""
22
- # ========== Token相关配置 ==========
 
 
 
 
 
23
  tokens_cfg = self.config['tokens']
24
  self.NUM_SVG_END = tokens_cfg['svg_end']
25
- self.BASE_OFFSET = tokens_cfg['base_offset']
26
- self.NUM_MASK_AND_EOM = tokens_cfg['num_mask_and_eom']
27
  self.NUM_END_TOKEN = tokens_cfg['num_end_token']
28
 
29
- # ========== 坐标相关配置 ==========
 
 
 
 
 
30
  coords_cfg = self.config['coordinates']
31
  self.BBOX = coords_cfg['bbox']
32
- self.PIX_PAD = coords_cfg['pix_pad_offset']
33
- self.COORD_PAD = coords_cfg['coord_pad_offset']
34
 
35
- # ========== 颜色相关配置 ==========
36
  colors_cfg = self.config['colors']
37
  self.COLOR_TOKEN_START_RAW = colors_cfg['color_token_start']
38
- self.COLOR_START_OFFSET = colors_cfg['color_start_offset']
39
- self.COLOR_END_OFFSET = colors_cfg['color_end_offset']
40
  self.MAX_COLOR_TOKENS = colors_cfg['max_color_tokens']
41
 
42
- # ========== SVG命令值(用于 raster_svg 中的判断)==========
 
 
 
 
43
  commands_cfg = self.config['svg_commands']
44
  self.CMD_MOVE = commands_cfg['move']
45
  self.CMD_LINE = commands_cfg['line']
@@ -47,41 +91,37 @@ class SVGTokenizer:
47
  self.CMD_ARC = commands_cfg['arc']
48
  self.CMD_CLOSE = commands_cfg['close']
49
 
50
- # ========== 模型相关配置 ==========
51
  model_cfg = self.config['model']
52
  self.BOS_TOKEN_ID = model_cfg['bos_token_id']
53
  self.EOS_TOKEN_ID = model_cfg['eos_token_id']
54
  self.PAD_TOKEN_ID = model_cfg['pad_token_id']
55
 
56
- # ========== Arc参数配置 ==========
57
  arc_cfg = self.config.get('arc', {})
58
  self.ARC_PARAM_OFFSET = arc_cfg.get('param_offset', 44500)
59
  self.ARC_PARAM_RANGE = arc_cfg.get('param_range', 100)
60
  self.ARC_PARAM_START = self.ARC_PARAM_OFFSET + self.BASE_OFFSET
61
 
62
- # ========== 派生常量计算 ==========
63
- # PIXEL_OFFSET: 从配置推导
64
- # 命令token存储值 - BASE_OFFSET - PIXEL_OFFSET = CMD_MOVE
65
- # (NUM_MASK_AND_EOM + NUM_SVG_END) - BASE_OFFSET - PIXEL_OFFSET = CMD_MOVE
66
  self.PIXEL_OFFSET = (self.NUM_MASK_AND_EOM - self.BASE_OFFSET +
67
  self.NUM_SVG_END - self.CMD_MOVE)
68
 
69
- # 命令token的实际范围
70
  self.CMD_TOKEN_START = self.NUM_MASK_AND_EOM + self.NUM_SVG_END
71
  self.CMD_TOKEN_END = self.PIX_PAD + self.NUM_SVG_END
72
 
73
- # 坐标token起始
74
  self.COORD_TOKEN_START = self.PIX_PAD + self.NUM_SVG_END
75
 
76
- # 颜色token边界(坐标与颜色的分界)
77
  self.COLOR_COORD_BOUNDARY = self.COLOR_TOKEN_START_RAW + 1 + self.BASE_OFFSET
78
 
79
- # 颜色阈值(用于 raster_svg 中判断)
80
- # 减去 PIXEL_OFFSET 后的颜色token下限
81
  self.COLOR_THRESHOLD = self.COLOR_TOKEN_START_RAW - self.PIXEL_OFFSET + 1
82
 
83
  def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
84
- """按照 dataset.py 逻辑创建 pixel xy 的映射"""
85
  pixel2xy = {}
86
  x = np.linspace(0, self.BBOX - 1, self.BBOX)
87
  y = np.linspace(0, self.BBOX - 1, self.BBOX)
@@ -89,13 +129,12 @@ class SVGTokenizer:
89
  xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
90
 
91
  for pixel, xy in enumerate(xy_grid):
92
- # xy + COORD_PAD + NUM_SVG_END
93
  pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
94
 
95
  return pixel2xy
96
 
97
  def token_to_color(self, color_token: int) -> str:
98
- """按照 dataset.py 的 token_to_color 逻辑"""
99
  try:
100
  if color_token == self.COLOR_TOKEN_START_RAW:
101
  return "none"
@@ -123,37 +162,35 @@ class SVGTokenizer:
123
  return "#808080"
124
 
125
  def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
126
- """
127
- 按照 dataset.py 的 __getitem__ 逻辑处理 tokens
128
- """
129
- # 移除 bos/eos
130
  generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()
131
 
132
  sample_xys = []
133
 
134
  for pixel in generated_pixels:
135
  try:
136
- # 1. 命令tokens: CMD_TOKEN_START <= pixel < CMD_TOKEN_END
137
  if self.CMD_TOKEN_START <= pixel < self.CMD_TOKEN_END:
138
  xy = np.array([pixel - self.BASE_OFFSET,
139
  pixel - self.BASE_OFFSET]).astype(int)
140
  sample_xys.append(xy)
141
 
142
- # 2. 坐标tokens: COORD_TOKEN_START <= pixel < COLOR_COORD_BOUNDARY
143
  elif self.COORD_TOKEN_START <= pixel < self.COLOR_COORD_BOUNDARY:
144
  pixel_index = pixel - self.COORD_TOKEN_START
145
  if pixel_index in self.pixel2xy:
146
  xy = self.pixel2xy[pixel_index] - self.BASE_OFFSET
147
  sample_xys.append(xy)
148
 
149
- # 3. Arc参数: ARC_PARAM_START + 1 <= pixel < ARC_PARAM_START + 1 + ARC_PARAM_RANGE
150
  elif (self.ARC_PARAM_START + 1 <= pixel <
151
  self.ARC_PARAM_START + 1 + self.ARC_PARAM_RANGE):
152
  value = pixel - self.ARC_PARAM_START - 1
153
  xy = np.array([value, value]).astype(int)
154
  sample_xys.append(xy)
155
 
156
- # 4. 颜色tokens: COLOR_COORD_BOUNDARY <= pixel < ARC_PARAM_START
157
  elif self.COLOR_COORD_BOUNDARY <= pixel < self.ARC_PARAM_START:
158
  xy = np.array([pixel - self.BASE_OFFSET,
159
  pixel - self.BASE_OFFSET]).astype(int)
@@ -169,15 +206,12 @@ class SVGTokenizer:
169
  return np.array([]).reshape(0, 2)
170
 
171
  def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
172
- """
173
- 按照 dataset.py 的 raster_svg 逻辑
174
- 关键:pixels -= PIXEL_OFFSET 是核心转换步骤
175
- """
176
  try:
177
  if len(pixels) == 0:
178
  return [[]], []
179
 
180
- # ========== 关键步骤:减去 PIXEL_OFFSET ==========
181
  pixels = pixels - self.PIXEL_OFFSET
182
 
183
  svg_tensors = []
@@ -250,11 +284,11 @@ class SVGTokenizer:
250
  path_tensor.append(cmd_tensor.tolist())
251
  i += 2
252
 
253
- # 颜色token: pix[0] >= COLOR_THRESHOLD
254
  elif pix[0] >= self.COLOR_THRESHOLD:
255
  if path_tensor:
256
  svg_tensors.append(torch.tensor(path_tensor))
257
- # 逆转换:还原原始颜色token
258
  color_token = int(pix[0] + self.PIXEL_OFFSET - 1)
259
  color_tensors.append(color_token)
260
  path_tensor = []
@@ -266,7 +300,7 @@ class SVGTokenizer:
266
  print(f"Error at position {i}: {e}")
267
  break
268
 
269
- # 处理剩余路径(无颜色)
270
  if path_tensor:
271
  svg_tensors.append(torch.tensor(path_tensor))
272
 
@@ -280,7 +314,7 @@ class SVGTokenizer:
280
 
281
  def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
282
  colors: Optional[List[int]]) -> SVG:
283
- """应用颜色并创建最终SVG"""
284
  paths = []
285
 
286
  if not svg_tensors:
 
8
 
9
 
10
  class SVGTokenizer:
11
+ """SVG tokenizer - supports both 8B and 4B models via config.yaml"""
12
 
13
+ def __init__(self, config_path: str = "./config.yaml", model_size: str = None):
14
+ """
15
+ Initialize SVGTokenizer.
16
+
17
+ Args:
18
+ config_path: Path to config.yaml
19
+ model_size: Model size ("8B" or "4B"). If None, uses default from config.
20
+ """
21
  with open(config_path, 'r') as f:
22
  self.config = yaml.safe_load(f)
23
 
24
+ # Determine model size
25
+ self.model_size = model_size or self.config.get('default_model_size', '8B')
26
+ if self.model_size not in self.config.get('models', {}):
27
+ raise ValueError(f"Invalid model_size: {self.model_size}. Must be one of: {list(self.config.get('models', {}).keys())}")
28
+
29
  self._load_config()
30
  self.pixel2xy = self._create_pixel2xy_mapping()
31
 
32
+ def _get_model_specific_config(self, *keys):
33
+ """Get model-specific config value, with fallback to shared config."""
34
+ model_cfg = self.config.get('models', {}).get(self.model_size, {})
35
+
36
+ # Navigate through nested keys in model-specific config
37
+ value = model_cfg
38
+ for key in keys:
39
+ if isinstance(value, dict) and key in value:
40
+ value = value[key]
41
+ else:
42
+ value = None
43
+ break
44
+
45
+ # If not found in model-specific, try shared config
46
+ if value is None:
47
+ value = self.config
48
+ for key in keys:
49
+ if isinstance(value, dict) and key in value:
50
+ value = value[key]
51
+ else:
52
+ return None
53
+
54
+ return value
55
+
56
  def _load_config(self):
57
+ """Load all constants from configuration file with model-specific overrides."""
58
+ # ========== Token-related configs ==========
59
+ # Model-specific tokens
60
+ self.NUM_MASK_AND_EOM = self._get_model_specific_config('tokens', 'num_mask_and_eom')
61
+ self.BASE_OFFSET = self._get_model_specific_config('tokens', 'base_offset')
62
+
63
+ # Shared tokens
64
  tokens_cfg = self.config['tokens']
65
  self.NUM_SVG_END = tokens_cfg['svg_end']
 
 
66
  self.NUM_END_TOKEN = tokens_cfg['num_end_token']
67
 
68
+ # ========== Coordinate-related configs ==========
69
+ # Model-specific coordinates
70
+ self.PIX_PAD = self._get_model_specific_config('coordinates', 'pix_pad_offset')
71
+ self.COORD_PAD = self._get_model_specific_config('coordinates', 'coord_pad_offset')
72
+
73
+ # Shared coordinates
74
  coords_cfg = self.config['coordinates']
75
  self.BBOX = coords_cfg['bbox']
 
 
76
 
77
+ # ========== Color-related configs ==========
78
  colors_cfg = self.config['colors']
79
  self.COLOR_TOKEN_START_RAW = colors_cfg['color_token_start']
 
 
80
  self.MAX_COLOR_TOKENS = colors_cfg['max_color_tokens']
81
 
82
+ # Model-specific colors
83
+ self.COLOR_START_OFFSET = self._get_model_specific_config('colors', 'color_start_offset')
84
+ self.COLOR_END_OFFSET = self._get_model_specific_config('colors', 'color_end_offset')
85
+
86
+ # ========== SVG command values ==========
87
  commands_cfg = self.config['svg_commands']
88
  self.CMD_MOVE = commands_cfg['move']
89
  self.CMD_LINE = commands_cfg['line']
 
91
  self.CMD_ARC = commands_cfg['arc']
92
  self.CMD_CLOSE = commands_cfg['close']
93
 
94
+ # ========== Model-related configs ==========
95
  model_cfg = self.config['model']
96
  self.BOS_TOKEN_ID = model_cfg['bos_token_id']
97
  self.EOS_TOKEN_ID = model_cfg['eos_token_id']
98
  self.PAD_TOKEN_ID = model_cfg['pad_token_id']
99
 
100
+ # ========== Arc parameter configs ==========
101
  arc_cfg = self.config.get('arc', {})
102
  self.ARC_PARAM_OFFSET = arc_cfg.get('param_offset', 44500)
103
  self.ARC_PARAM_RANGE = arc_cfg.get('param_range', 100)
104
  self.ARC_PARAM_START = self.ARC_PARAM_OFFSET + self.BASE_OFFSET
105
 
106
+ # ========== Derived constants ==========
 
 
 
107
  self.PIXEL_OFFSET = (self.NUM_MASK_AND_EOM - self.BASE_OFFSET +
108
  self.NUM_SVG_END - self.CMD_MOVE)
109
 
110
+ # Command token range
111
  self.CMD_TOKEN_START = self.NUM_MASK_AND_EOM + self.NUM_SVG_END
112
  self.CMD_TOKEN_END = self.PIX_PAD + self.NUM_SVG_END
113
 
114
+ # Coordinate token start
115
  self.COORD_TOKEN_START = self.PIX_PAD + self.NUM_SVG_END
116
 
117
+ # Color-coordinate boundary
118
  self.COLOR_COORD_BOUNDARY = self.COLOR_TOKEN_START_RAW + 1 + self.BASE_OFFSET
119
 
120
+ # Color threshold for raster_svg
 
121
  self.COLOR_THRESHOLD = self.COLOR_TOKEN_START_RAW - self.PIXEL_OFFSET + 1
122
 
123
  def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
124
+ """Create pixel to xy mapping following dataset.py logic."""
125
  pixel2xy = {}
126
  x = np.linspace(0, self.BBOX - 1, self.BBOX)
127
  y = np.linspace(0, self.BBOX - 1, self.BBOX)
 
129
  xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
130
 
131
  for pixel, xy in enumerate(xy_grid):
 
132
  pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
133
 
134
  return pixel2xy
135
 
136
  def token_to_color(self, color_token: int) -> str:
137
+ """Convert token to color following dataset.py logic."""
138
  try:
139
  if color_token == self.COLOR_TOKEN_START_RAW:
140
  return "none"
 
162
  return "#808080"
163
 
164
  def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
165
+ """Process generated tokens following dataset.py logic."""
166
+ # Remove bos/eos
 
 
167
  generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()
168
 
169
  sample_xys = []
170
 
171
  for pixel in generated_pixels:
172
  try:
173
+ # 1. Command tokens: CMD_TOKEN_START <= pixel < CMD_TOKEN_END
174
  if self.CMD_TOKEN_START <= pixel < self.CMD_TOKEN_END:
175
  xy = np.array([pixel - self.BASE_OFFSET,
176
  pixel - self.BASE_OFFSET]).astype(int)
177
  sample_xys.append(xy)
178
 
179
+ # 2. Coordinate tokens: COORD_TOKEN_START <= pixel < COLOR_COORD_BOUNDARY
180
  elif self.COORD_TOKEN_START <= pixel < self.COLOR_COORD_BOUNDARY:
181
  pixel_index = pixel - self.COORD_TOKEN_START
182
  if pixel_index in self.pixel2xy:
183
  xy = self.pixel2xy[pixel_index] - self.BASE_OFFSET
184
  sample_xys.append(xy)
185
 
186
+ # 3. Arc parameters: ARC_PARAM_START + 1 <= pixel < ARC_PARAM_START + 1 + ARC_PARAM_RANGE
187
  elif (self.ARC_PARAM_START + 1 <= pixel <
188
  self.ARC_PARAM_START + 1 + self.ARC_PARAM_RANGE):
189
  value = pixel - self.ARC_PARAM_START - 1
190
  xy = np.array([value, value]).astype(int)
191
  sample_xys.append(xy)
192
 
193
+ # 4. Color tokens: COLOR_COORD_BOUNDARY <= pixel < ARC_PARAM_START
194
  elif self.COLOR_COORD_BOUNDARY <= pixel < self.ARC_PARAM_START:
195
  xy = np.array([pixel - self.BASE_OFFSET,
196
  pixel - self.BASE_OFFSET]).astype(int)
 
206
  return np.array([]).reshape(0, 2)
207
 
208
  def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
209
+ """Convert pixels to SVG tensors following dataset.py logic."""
 
 
 
210
  try:
211
  if len(pixels) == 0:
212
  return [[]], []
213
 
214
+ # Key step: subtract PIXEL_OFFSET
215
  pixels = pixels - self.PIXEL_OFFSET
216
 
217
  svg_tensors = []
 
284
  path_tensor.append(cmd_tensor.tolist())
285
  i += 2
286
 
287
+ # Color token: pix[0] >= COLOR_THRESHOLD
288
  elif pix[0] >= self.COLOR_THRESHOLD:
289
  if path_tensor:
290
  svg_tensors.append(torch.tensor(path_tensor))
291
+ # Reverse transform: restore original color token
292
  color_token = int(pix[0] + self.PIXEL_OFFSET - 1)
293
  color_tensors.append(color_token)
294
  path_tensor = []
 
300
  print(f"Error at position {i}: {e}")
301
  break
302
 
303
+ # Handle remaining path (without color)
304
  if path_tensor:
305
  svg_tensors.append(torch.tensor(path_tensor))
306
 
 
314
 
315
  def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
316
  colors: Optional[List[int]]) -> SVG:
317
+ """Apply colors and create final SVG."""
318
  paths = []
319
 
320
  if not svg_tensors: