Remove hardcoded .cuda() calls to support single forward pass on CPU and ensure DeepSeekOCR model compatibility with transformers==4.52.4

#54

Problem Description

  • The hardcoded .cuda() conversion in the forward method of DeepseekOCRModel blocks CPU inference for a single forward pass (a device-agnostic sketch follows this list).
  • The model also has compatibility issues with LlamaAttention on transformers==4.52.4.
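A minimal sketch of the device-agnostic pattern behind the fix; the tensor names (images, inputs_embeds) and shapes below are illustrative stand-ins, not the exact code in DeepseekOCRModel.forward:

import torch

def to_model_device(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Follow whatever device the reference tensor already lives on,
    # instead of hardcoding .cuda()
    return tensor.to(reference.device)

# Illustrative usage: the image tensor follows inputs_embeds,
# so a single forward pass works on CPU as well as GPU
inputs_embeds = torch.zeros(1, 4, 8)              # stays on CPU in a CPU-only run
images = torch.zeros(1, 3, 224, 224)
images = to_model_device(images, inputs_embeds)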

What's changed

  • Removed the hardcoded .cuda() calls from here so a single forward pass can run on CPU (see the device-agnostic sketch under Problem Description).

  • Removed LlamaFlashAttention2, as it is not present in 4.52.4 and is not used by the model.

  • Position embeddings update (see the sketch after this list):

    • In 4.46.3 the position_embeddings parameter is optional, but in 4.52.4 it is mandatory.
    • In 4.46.3, position_embeddings is computed inside LlamaAttention like this.
    • In 4.52.4, it is computed outside LlamaAttention and passed in.
    • To fix this, position_embeddings is now computed outside, like this in here, and passed to LlamaAttention.
    • LlamaRotaryEmbedding was imported for this purpose.
  • Attention output update:

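Below is a minimal sketch, assuming transformers==4.52.4, of computing the rotary (cos, sin) pair with LlamaRotaryEmbedding outside the attention module and passing it through the now-mandatory position_embeddings argument; the config values, tensor shapes, and mask handling are illustrative, not the exact DeepSeek-OCR code:

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=1)
rotary_emb = LlamaRotaryEmbedding(config=config)
attn = LlamaAttention(config=config, layer_idx=0)

hidden_states = torch.zeros(1, 6, config.hidden_size)
position_ids = torch.arange(6).unsqueeze(0)

# In 4.46.3 this (cos, sin) pair was produced inside LlamaAttention;
# in 4.52.4 the caller computes it and passes it in explicitly.
position_embeddings = rotary_emb(hidden_states, position_ids)

attn_output, attn_weights = attn(
    hidden_states=hidden_states,
    position_embeddings=position_embeddings,
    attention_mask=None,   # mask preparation omitted for brevity
)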

It seems the error hasn't been resolved yet... the following error occurs (flash_attention_2 is specified for the attention implementation):

File "${HOME}/.cache/huggingface/modules/transformers_modules/deepseek_hyphen_ai/DeepSeek_hyphen_OCR/209ae73bb0f2e5b377f8eb18b09419014970ac34/modeling_deepseekv2.py", line 1253, in __init__
    self.self_attn = ATTENTION_CLASSES[attn_implementation](
KeyError: 'mha_flash_attention_2'

The output of uv pip freeze is as follows:

$ uv pip freeze
accelerate==1.11.0
addict==2.4.0
aiohappyeyeballs==2.6.1
aiohttp==3.13.2
aiosignal==1.4.0
annotated-doc==0.0.3
annotated-types==0.7.0
anyio==4.11.0
astor==0.8.1
attrs==25.4.0
blake3==1.0.8
cachetools==6.2.1
cbor2==5.7.1
certifi==2025.10.5
cffi==2.0.0
charset-normalizer==3.4.4
click==8.2.1
cloudpickle==3.1.1
compressed-tensors==0.11.0
contourpy==1.3.3
cupy-cuda12x==13.6.0
cycler==0.12.1
depyf==0.19.0
dill==0.4.0
diskcache==5.6.3
distro==1.9.0
dnspython==2.8.0
easydict==1.13
einops==0.8.1
email-validator==2.3.0
fastapi==0.120.1
fastapi-cli==0.0.14
fastapi-cloud-cli==0.3.1
fastrlock==0.8.3
filelock==3.20.0
flash-attn==2.8.3
fonttools==4.60.1
frozendict==2.4.6
frozenlist==1.8.0
fsspec==2025.9.0
gguf==0.17.1
h11==0.16.0
hf-transfer==0.1.9
hf-xet==1.2.0
httpcore==1.0.9
httptools==0.7.1
httpx==0.28.1
huggingface-hub==0.36.0
idna==3.11
interegular==0.3.3
jinja2==3.1.6
jiter==0.11.1
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
lark==1.2.2
llguidance==0.7.30
llvmlite==0.44.0
lm-format-enforcer==0.11.3
markdown-it-py==4.0.0
markupsafe==3.0.3
matplotlib==3.10.7
mdurl==0.1.2
mistral-common==1.8.5
mpmath==1.3.0
msgpack==1.1.2
msgspec==0.19.0
multidict==6.7.0
networkx==3.5
ninja==1.13.0
numba==0.61.2
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvtx-cu12==12.8.90
openai==2.6.1
openai-harmony==0.0.4
opencv-python-headless==4.12.0.88
outlines-core==0.2.11
packaging==25.0
partial-json-parser==0.2.1.1.post6
pillow==12.0.0
prometheus-client==0.23.1
prometheus-fastapi-instrumentator==7.1.0
propcache==0.4.1
protobuf==6.33.0
psutil==7.1.2
py-cpuinfo==9.0.0
pybase64==1.4.2
pycountry==24.6.1
pycparser==2.23
pydantic==2.12.3
pydantic-core==2.41.4
pydantic-extra-types==2.10.6
pygments==2.19.2
pyparsing==3.2.5
python-dateutil==2.9.0.post0
python-dotenv==1.2.1
python-json-logger==4.0.0
python-multipart==0.0.20
pyyaml==6.0.3
pyzmq==27.1.0
ray==2.51.0
referencing==0.37.0
regex==2025.10.23
requests==2.32.5
rich==14.2.0
rich-toolkit==0.15.1
rignore==0.7.1
rpds-py==0.28.0
safetensors==0.6.2
scipy==1.16.3
sentencepiece==0.2.1
sentry-sdk==2.42.1
setproctitle==1.3.7
setuptools==79.0.1
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
soundfile==0.13.1
soxr==1.0.0
starlette==0.49.1
sympy==1.14.0
tiktoken==0.12.0
tokenizers==0.22.1
torch==2.8.0
torchaudio==2.8.0
torchvision==0.23.0
tqdm==4.67.1
transformers==4.57.1
triton==3.4.0
typer==0.20.0
typing-extensions==4.15.0
typing-inspection==0.4.2
urllib3==2.5.0
uvicorn==0.38.0
uvloop==0.22.1
vllm==0.11.0
watchfiles==1.1.1
websockets==15.0.1
xformers==0.0.32.post1
xgrammar==0.1.25
yarl==1.22.0


Root Cause

  • The error occurs because the attention implementation 'mha_flash_attention_2' is mapped to LlamaFlashAttention2 here, and that class does not exist in the transformers.models.llama.modeling_llama module.
  • The class is missing in both versions: 4.57.1 (yours) and 4.52.4 (my target version).
  • As a result, the key lookup in ATTENTION_CLASSES fails, causing the KeyError: 'mha_flash_attention_2' (see the illustrative sketch below).
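For illustration, a hedged sketch of the failing lookup; the class name and dictionary contents below are stand-ins rather than the exact contents of modeling_deepseekv2.py:

class DeepseekV2Attention:            # stand-in for the model's eager attention class
    pass

# Hypothetical registry mirroring ATTENTION_CLASSES: the flash-attention entry
# cannot be registered because LlamaFlashAttention2 is not importable from
# transformers.models.llama.modeling_llama in 4.52.4 or 4.57.1.
ATTENTION_CLASSES = {
    "mha_eager": DeepseekV2Attention,
}

attn_implementation = "mha_flash_attention_2"         # requested flash_attention_2 variant
self_attn = ATTENTION_CLASSES[attn_implementation]    # raises KeyError: 'mha_flash_attention_2'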