Commit
·
feaa3d5
1
Parent(s):
2563d19
add intel xpu platform support
Browse filesSigned-off-by: Liu, Kaixuan <[email protected]>
- modeling_cogvlm.py +9 -3
- util.py +7 -1
- visual.py +2 -0
modeling_cogvlm.py
CHANGED
|
@@ -8,6 +8,7 @@ from torch import nn
|
|
| 8 |
from torch.nn import CrossEntropyLoss
|
| 9 |
from torchvision import transforms
|
| 10 |
from einops import rearrange
|
|
|
|
| 11 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
| 12 |
from transformers.utils.logging import get_logger
|
| 13 |
from transformers.activations import ACT2FN
|
|
@@ -723,9 +724,14 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
|
|
| 723 |
standardize_cache_format: bool = False,
|
| 724 |
) -> Dict[str, Any]:
|
| 725 |
# update past_key_values
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
model_kwargs[cache_name] = cache
|
| 730 |
|
| 731 |
if getattr(outputs, "state", None) is not None:
|
|
|
|
| 8 |
from torch.nn import CrossEntropyLoss
|
| 9 |
from torchvision import transforms
|
| 10 |
from einops import rearrange
|
| 11 |
+
import transformers
|
| 12 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
| 13 |
from transformers.utils.logging import get_logger
|
| 14 |
from transformers.activations import ACT2FN
|
|
|
|
| 724 |
standardize_cache_format: bool = False,
|
| 725 |
) -> Dict[str, Any]:
|
| 726 |
# update past_key_values
|
| 727 |
+
if transformers.__version__ >= "4.44.0":
|
| 728 |
+
cache_name, cache = self._extract_past_from_model_output(
|
| 729 |
+
outputs
|
| 730 |
+
)
|
| 731 |
+
else:
|
| 732 |
+
cache_name, cache = self._extract_past_from_model_output(
|
| 733 |
+
outputs, standardize_cache_format=standardize_cache_format
|
| 734 |
+
)
|
| 735 |
model_kwargs[cache_name] = cache
|
| 736 |
|
| 737 |
if getattr(outputs, "state", None) is not None:
|
util.py
CHANGED
|
@@ -7,6 +7,10 @@ import torch.nn.functional as F
|
|
| 7 |
import triton
|
| 8 |
import triton.language as tl
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
@triton.jit
|
| 12 |
def rotary_kernel(
|
|
@@ -197,7 +201,9 @@ def apply_rotary(
|
|
| 197 |
|
| 198 |
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
| 199 |
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
| 200 |
-
|
|
|
|
|
|
|
| 201 |
rotary_kernel[grid](
|
| 202 |
output, # data ptrs
|
| 203 |
x,
|
|
|
|
| 7 |
import triton
|
| 8 |
import triton.language as tl
|
| 9 |
|
| 10 |
+
device_contexts = {
|
| 11 |
+
'cuda': torch.cuda.device,
|
| 12 |
+
'xpu': torch.xpu.device
|
| 13 |
+
}
|
| 14 |
|
| 15 |
@triton.jit
|
| 16 |
def rotary_kernel(
|
|
|
|
| 201 |
|
| 202 |
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
| 203 |
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
| 204 |
+
device_type = x.device.type
|
| 205 |
+
assert device_type in device_contexts
|
| 206 |
+
with device_contexts[device_type](x.device.index):
|
| 207 |
rotary_kernel[grid](
|
| 208 |
output, # data ptrs
|
| 209 |
x,
|
visual.py
CHANGED
|
@@ -75,6 +75,8 @@ class Attention(nn.Module):
|
|
| 75 |
out = out.transpose(2, 1)
|
| 76 |
# breakpoint()
|
| 77 |
# output = self.dense(out.reshape(B, L, -1))
|
|
|
|
|
|
|
| 78 |
output = self.dense(out.view(B, L, -1))
|
| 79 |
output = self.output_dropout(output)
|
| 80 |
return output
|
|
|
|
| 75 |
out = out.transpose(2, 1)
|
| 76 |
# breakpoint()
|
| 77 |
# output = self.dense(out.reshape(B, L, -1))
|
| 78 |
+
if not out.is_contiguous():
|
| 79 |
+
out = out.contiguous()
|
| 80 |
output = self.dense(out.view(B, L, -1))
|
| 81 |
output = self.output_dropout(output)
|
| 82 |
return output
|