OmniSVG commited on
Commit
dbc468c
·
verified ·
1 Parent(s): 64db0c1

Update decoder.py

Browse files
Files changed (1) hide show
  1. decoder.py +2 -2
decoder.py CHANGED
@@ -26,8 +26,8 @@ class SketchDecoder(nn.Module):
26
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
27
  "Qwen/Qwen2.5-VL-3B-Instruct",
28
  config=config,
29
- #torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
30
- #device_map ="cuda",
31
  ignore_mismatched_sizes=True
32
  )
33
 
 
26
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
27
  "Qwen/Qwen2.5-VL-3B-Instruct",
28
  config=config,
29
+ torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
30
+ device_map ="cuda",
31
  ignore_mismatched_sizes=True
32
  )
33