Spaces:
Build error
Build error
Fix device error when using cuda (#4)
Browse files- Fix device error when using cuda (030c843a1bb758298a3f0bc6f2564a26aaff878e)
Co-authored-by: Ma Jinyu <[email protected]>
- models/tag2text.py +2 -3
models/tag2text.py
CHANGED
|
@@ -152,8 +152,7 @@ class RAM(nn.Module):
|
|
| 152 |
self.class_threshold[key] = value
|
| 153 |
|
| 154 |
def load_tag_list(self, tag_list_file):
|
| 155 |
-
with open(tag_list_file, 'r', encoding="
|
| 156 |
-
# with open(tag_list_file, 'r') as f:
|
| 157 |
tag_list = f.read().splitlines()
|
| 158 |
tag_list = np.array(tag_list)
|
| 159 |
return tag_list
|
|
@@ -362,7 +361,7 @@ class Tag2Text_Caption(nn.Module):
|
|
| 362 |
logits = self.fc(tagging_embed[0])
|
| 363 |
|
| 364 |
targets = torch.where(
|
| 365 |
-
torch.sigmoid(logits) > self.class_threshold,
|
| 366 |
torch.tensor(1.0).to(image.device),
|
| 367 |
torch.zeros(self.num_class).to(image.device))
|
| 368 |
|
|
|
|
| 152 |
self.class_threshold[key] = value
|
| 153 |
|
| 154 |
def load_tag_list(self, tag_list_file):
|
| 155 |
+
with open(tag_list_file, 'r', encoding="utf8") as f:
|
|
|
|
| 156 |
tag_list = f.read().splitlines()
|
| 157 |
tag_list = np.array(tag_list)
|
| 158 |
return tag_list
|
|
|
|
| 361 |
logits = self.fc(tagging_embed[0])
|
| 362 |
|
| 363 |
targets = torch.where(
|
| 364 |
+
torch.sigmoid(logits) > self.class_threshold.to(image.device),
|
| 365 |
torch.tensor(1.0).to(image.device),
|
| 366 |
torch.zeros(self.num_class).to(image.device))
|
| 367 |
|